mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Merge commit 'a1589a9667517ddc73048c05c6f3c859db99851d' into develop
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
|
||||
template <typename ProblemType>
|
||||
bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
@@ -53,6 +54,17 @@ bool run_gemm(const ProblemType& problem_size, const ExecutionConfig& config)
|
||||
StrideB = f_get_default_stride(K, N, StrideB, BLayout{});
|
||||
StrideC = f_get_default_stride(M, N, StrideC, CLayout{});
|
||||
|
||||
try
|
||||
{
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Error: " << e.what() << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -29,10 +29,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
|
||||
@@ -44,7 +40,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::PersistentTileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
@@ -53,8 +48,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
|
||||
50
include/ck/library/utility/validation_common.hpp
Normal file
50
include/ck/library/utility/validation_common.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/utility/type.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
|
||||
template <typename Layout>
|
||||
inline void
|
||||
validate_gemm_stride(int M, int N, int stride, const std::string& stride_name = "Stride")
|
||||
{
|
||||
if(ck::is_same_v<Layout, ck::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if(stride < M)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Error: For ColumnMajor layout, " + stride_name + " (" + std::to_string(stride) +
|
||||
") must be greater than or equal to dim (" + std::to_string(M) + ")");
|
||||
}
|
||||
}
|
||||
else // RowMajor
|
||||
{
|
||||
if(stride < N)
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"Error: For RowMajor layout, " + stride_name + " (" + std::to_string(stride) +
|
||||
") must be greater than or equal to dim (" + std::to_string(N) + ")");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience functions for common GEMM patterns
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
inline void validate_gemm_strides_abc(int M, int N, int K, int StrideA, int StrideB, int StrideC)
|
||||
{
|
||||
validate_gemm_stride<ALayout>(M, K, StrideA, "StrideA");
|
||||
validate_gemm_stride<BLayout>(K, N, StrideB, "StrideB");
|
||||
validate_gemm_stride<CLayout>(M, N, StrideC, "StrideC");
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace ck
|
||||
@@ -155,7 +155,17 @@ struct GroupedGemmKernel
|
||||
return group_count * sizeof(GemmTransKernelArg);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() -> dim3 { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static auto BlockSize() -> dim3
|
||||
{
|
||||
if(is_wave32())
|
||||
{
|
||||
return dim3(kBlockSize / 2);
|
||||
}
|
||||
else
|
||||
{
|
||||
return dim3(kBlockSize);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the maximum occupancy grid size for the persistent kernel on the current device.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -74,6 +75,10 @@ bool profile_gemm_ab_scale_impl(int do_verification,
|
||||
? ((K + ScaleBlockK - 1) / ScaleBlockK)
|
||||
: ((N + ScaleBlockN - 1) / ScaleBlockN);
|
||||
|
||||
ck::utils::validate_gemm_stride<ALayout>(M, K, StrideA, "StrideA");
|
||||
ck::utils::validate_gemm_stride<BLayout>(K, N, StrideB, "StrideB");
|
||||
ck::utils::validate_gemm_stride<BLayout>(M, N, StrideE, "StrideE");
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<A1DataType> a1_m_k(f_host_tensor_descriptor((M + ScaleBlockM - 1) / ScaleBlockM,
|
||||
(K + ScaleBlockK - 1) / ScaleBlockK,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -93,6 +94,9 @@ void profile_gemm_bias_add_reduce_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor2d(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor2d(K, N, StrideB, BLayout{}));
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -104,6 +105,10 @@ bool profile_gemm_blockscale_weighpreshuffle_impl(int do_verification,
|
||||
? ((K + ScaleBlockK - 1) / ScaleBlockK)
|
||||
: ((N + ScaleBlockN - 1) / ScaleBlockN);
|
||||
|
||||
ck::utils::validate_gemm_stride<ALayout>(M, K, StrideA, "StrideA");
|
||||
ck::utils::validate_gemm_stride<BLayout>(K, N, StrideB, "StrideB");
|
||||
ck::utils::validate_gemm_stride<BLayout>(M, N, StrideE, "StrideE");
|
||||
|
||||
Tensor<A0DataType> a0_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<A1DataType> a1_m_k(f_host_tensor_descriptor((M + ScaleBlockM - 1) / ScaleBlockM,
|
||||
(K + ScaleBlockK - 1) / ScaleBlockK,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -24,6 +24,7 @@
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/utility/fill.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace profiler {
|
||||
@@ -64,6 +65,9 @@ int profile_gemm_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -88,6 +89,9 @@ bool profile_gemm_reduce_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -62,6 +63,9 @@ bool profile_gemm_splitk_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -59,6 +60,9 @@ bool profile_gemm_streamk_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -63,6 +64,9 @@ bool profile_gemm_universal_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -91,6 +92,9 @@ bool profile_gemm_universal_preshuffle_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<BDataType> b_k_n_permute(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -64,6 +65,9 @@ bool profile_gemm_universal_reduce_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
6
profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp
Executable file → Normal file
6
profiler/include/profiler/profile_gemm_universal_streamk_impl.hpp
Executable file → Normal file
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
#include "ck/library/utility/host_tensor.hpp"
|
||||
#include "ck/library/utility/host_tensor_generator.hpp"
|
||||
#include "ck/library/utility/literals.hpp"
|
||||
#include "ck/library/utility/validation_common.hpp"
|
||||
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/reference_gemm.hpp"
|
||||
|
||||
@@ -67,6 +68,9 @@ bool profile_gemm_universal_streamk_impl(int do_verification,
|
||||
}
|
||||
};
|
||||
|
||||
ck::utils::validate_gemm_strides_abc<ALayout, BLayout, CLayout>(
|
||||
M, N, K, StrideA, StrideB, StrideC);
|
||||
|
||||
Tensor<ADataType> a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{}));
|
||||
Tensor<BDataType> b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{}));
|
||||
Tensor<CDataType> c_m_n_host_result(f_host_tensor_descriptor(M, N, StrideC, CLayout{}));
|
||||
|
||||
Reference in New Issue
Block a user