mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-07 00:04:37 +00:00
Merge remote-tracking branch 'upstream/develop' into ck_migraphx_integration
This commit is contained in:
@@ -406,7 +406,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ static constexpr auto TailScheduler<1>()
|
||||
__device__ constexpr auto TailScheduler<1>()
|
||||
{
|
||||
// schedule
|
||||
constexpr auto num_ds_read_inst =
|
||||
@@ -433,7 +433,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ static constexpr auto TailScheduler<2>()
|
||||
__device__ constexpr auto TailScheduler<2>()
|
||||
{
|
||||
// schedule
|
||||
constexpr auto num_ds_read_inst =
|
||||
|
||||
@@ -324,55 +324,55 @@ struct DppSelector
|
||||
static constexpr auto GetDpp();
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 8, 32>()
|
||||
constexpr auto GetDpp<half_t, 8, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_8x32x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 8, 16>()
|
||||
constexpr auto GetDpp<half_t, 8, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_8x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 16, 16>()
|
||||
constexpr auto GetDpp<half_t, 16, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_16x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 32, 8>()
|
||||
constexpr auto GetDpp<half_t, 32, 8>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_32x8x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 1, 32>()
|
||||
constexpr auto GetDpp<half_t, 1, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_1x32x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 2, 32>()
|
||||
constexpr auto GetDpp<half_t, 2, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_2x32x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 2, 16>()
|
||||
constexpr auto GetDpp<half_t, 2, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_2x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 4, 16>()
|
||||
constexpr auto GetDpp<half_t, 4, 16>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_4x16x2;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetDpp<half_t, 4, 32>()
|
||||
constexpr auto GetDpp<half_t, 4, 32>()
|
||||
{
|
||||
return DppInstr::dpp8_f16_4x32x2;
|
||||
}
|
||||
|
||||
@@ -415,7 +415,7 @@ struct WmmaSelector
|
||||
static constexpr auto GetWmma();
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
|
||||
constexpr auto GetWmma<half_t, half_t, float, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_f32_16x16x16_f16_gfx12;
|
||||
@@ -425,7 +425,7 @@ struct WmmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
|
||||
constexpr auto GetWmma<bhalf_t, bhalf_t, float, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_f32_16x16x16_bf16_gfx12;
|
||||
@@ -435,19 +435,19 @@ struct WmmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
|
||||
constexpr auto GetWmma<half_t, half_t, half_t, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_f16_16x16x16_f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
|
||||
constexpr auto GetWmma<bhalf_t, bhalf_t, bhalf_t, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_bf16_16x16x16_bf16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
|
||||
constexpr auto GetWmma<int8_t, int8_t, int, 16, 16>()
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu8_gfx12;
|
||||
@@ -458,7 +458,7 @@ struct WmmaSelector
|
||||
|
||||
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
|
||||
template <>
|
||||
static constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
|
||||
constexpr auto GetWmma<int4_t, int4_t, int, 16, 16>()
|
||||
{
|
||||
return WmmaInstr::wmma_i32_16x16x16_iu4;
|
||||
}
|
||||
|
||||
@@ -651,97 +651,97 @@ struct MfmaSelector
|
||||
static constexpr auto GetMfma();
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<double, 16, 16>()
|
||||
constexpr auto GetMfma<double, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f64_16x16x4f64;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 64, 64>()
|
||||
constexpr auto GetMfma<float, 64, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 32, 64>()
|
||||
constexpr auto GetMfma<float, 32, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 16, 64>()
|
||||
constexpr auto GetMfma<float, 16, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 8, 64>()
|
||||
constexpr auto GetMfma<float, 8, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 4, 64>()
|
||||
constexpr auto GetMfma<float, 4, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x1xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 32, 32>()
|
||||
constexpr auto GetMfma<float, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x2xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<float, 16, 16>()
|
||||
constexpr auto GetMfma<float, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x4xf32;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 64, 64>()
|
||||
constexpr auto GetMfma<half_t, 64, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 32, 64>()
|
||||
constexpr auto GetMfma<half_t, 32, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 32, 32>()
|
||||
constexpr auto GetMfma<half_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x8f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 16, 16>()
|
||||
constexpr auto GetMfma<half_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x16f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 16, 64>()
|
||||
constexpr auto GetMfma<half_t, 16, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 8, 64>()
|
||||
constexpr auto GetMfma<half_t, 8, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<half_t, 4, 64>()
|
||||
constexpr auto GetMfma<half_t, 4, 64>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_4x4x4f16;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bhalf_t, 32, 32>()
|
||||
constexpr auto GetMfma<bhalf_t, 32, 32>()
|
||||
{
|
||||
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
|
||||
@@ -751,7 +751,7 @@ struct MfmaSelector
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bhalf_t, 16, 16>()
|
||||
constexpr auto GetMfma<bhalf_t, 16, 16>()
|
||||
{
|
||||
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
|
||||
@@ -762,72 +762,72 @@ struct MfmaSelector
|
||||
|
||||
#if defined(CK_USE_AMD_MFMA_GFX940)
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x16i8;
|
||||
}
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
}
|
||||
#else
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x8i8;
|
||||
}
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
}
|
||||
#endif
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
constexpr auto GetMfma<f8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
constexpr auto GetMfma<bf8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 16, 16>()
|
||||
constexpr auto GetMfma<bf8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
|
||||
constexpr auto GetMfma<f8_t, 32, 32, bf8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16f8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
|
||||
constexpr auto GetMfma<f8_t, 16, 16, bf8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32f8bf8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
|
||||
constexpr auto GetMfma<bf8_t, 32, 32, f8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_32x32x16bf8f8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
|
||||
constexpr auto GetMfma<bf8_t, 16, 16, f8_t>()
|
||||
{
|
||||
return MfmaInstr::mfma_f32_16x16x32bf8f8;
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <initializer_list>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/numeric/integer.hpp"
|
||||
@@ -236,6 +237,16 @@ CK_TILE_HOST_DEVICE constexpr bool operator!=(const array<T, Size>& a, const arr
|
||||
return !(a == b);
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array(const std::vector<X>& x)
|
||||
{
|
||||
array<T, N> arr;
|
||||
|
||||
static_for<0, N, 1>{}([&x, &arr](auto i) { arr(i) = x[i]; });
|
||||
|
||||
return arr;
|
||||
}
|
||||
|
||||
template <typename T, index_t N, typename X>
|
||||
CK_TILE_HOST_DEVICE constexpr auto to_array(const X& x)
|
||||
{
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include "ck_tile/host/arg_parser.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
#include "ck_tile/host/convolution_host_tensor_descriptor_helper.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/host/device_memory.hpp"
|
||||
#include "ck_tile/host/fill.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace conv {
|
||||
namespace detail {
|
||||
|
||||
template <typename OldLayout>
|
||||
CK_TILE_HOST std::vector<std::size_t> get_layout_transpose_gnchw_to_old()
|
||||
{
|
||||
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCW> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCX> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKW>)
|
||||
{
|
||||
return {0, 1, 2, 3};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKHW>)
|
||||
{
|
||||
return {0, 1, 2, 3, 4};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNCDHW> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKCZYX> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
|
||||
{
|
||||
return {0, 1, 2, 3, 4, 5};
|
||||
}
|
||||
if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKXC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNWK>)
|
||||
{
|
||||
return {0, 1, 3, 2};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNHWK>)
|
||||
{
|
||||
return {0, 1, 4, 2, 3};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GKZYXC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
|
||||
{
|
||||
return {0, 1, 5, 2, 3, 4};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KXGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NWGK>)
|
||||
{
|
||||
return {2, 0, 3, 1};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NHWGK>)
|
||||
{
|
||||
return {3, 0, 4, 1, 2};
|
||||
}
|
||||
else if constexpr(std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::KZYXGC> ||
|
||||
std::is_same_v<OldLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
|
||||
{
|
||||
return {4, 0, 5, 1, 2, 3};
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// make tensor descriptor for packed input tensor, and order the dimension in the order of GNCHW
|
||||
// regardless of physical layout
|
||||
template <typename InLayout>
|
||||
CK_TILE_HOST HostTensorDescriptor
|
||||
make_input_host_tensor_descriptor_g_n_c_wis_packed(const ck_tile::conv::ConvParam& param)
|
||||
{
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCW> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCHW> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNCDHW>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNWC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNHWC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::GNDHWC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NWGC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NHWGC> ||
|
||||
std::is_same_v<InLayout, ck_tile::tensor_layout::convolution::NDHWGC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 1,
|
||||
param.input_spatial_lengths_.begin(),
|
||||
param.input_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
printf("%s\n", InLayout::name);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
|
||||
return transpose_host_tensor_descriptor_given_new2old(
|
||||
HostTensorDescriptor(physical_lengths),
|
||||
detail::get_layout_transpose_gnchw_to_old<InLayout>());
|
||||
}
|
||||
|
||||
// make tensor descriptor for packed weight tensor, and order the dimension in the order of GKCYX
|
||||
// regardless of physical layout
|
||||
template <typename WeiLayout>
|
||||
CK_TILE_HOST HostTensorDescriptor
|
||||
make_weight_host_tensor_descriptor_g_k_c_xs_packed(const ck_tile::conv::ConvParam& param)
|
||||
{
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXC>)
|
||||
{
|
||||
if(param.G_ != 1)
|
||||
{
|
||||
throw std::runtime_error("wrong! G != 1");
|
||||
}
|
||||
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCX> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCYX> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKCZYX>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKYXC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::GKZYXC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KXGC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KYXGC> ||
|
||||
std::is_same_v<WeiLayout, ck_tile::tensor_layout::convolution::KZYXGC>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.K_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.C_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 1,
|
||||
param.filter_spatial_lengths_.begin(),
|
||||
param.filter_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
printf("%s\n", WeiLayout::name);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
|
||||
return transpose_host_tensor_descriptor_given_new2old(
|
||||
HostTensorDescriptor(physical_lengths),
|
||||
detail::get_layout_transpose_gnchw_to_old<WeiLayout>());
|
||||
}
|
||||
|
||||
// make tensor descriptor for packed output tensor, and order the dimension in the order of GNKHW
|
||||
// regardless of physical layout
|
||||
template <typename OutLayout>
|
||||
CK_TILE_HOST HostTensorDescriptor
|
||||
make_output_host_tensor_descriptor_g_n_k_wos_packed(const ck_tile::conv::ConvParam& param)
|
||||
{
|
||||
std::vector<std::size_t> physical_lengths;
|
||||
|
||||
if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKW> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKHW> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNKDHW>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.end(),
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
// separate from legacy code above
|
||||
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNWK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNHWK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::GNDHWK>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 2,
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else if constexpr(std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NWGK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NHWGK> ||
|
||||
std::is_same_v<OutLayout, ck_tile::tensor_layout::convolution::NDHWGK>)
|
||||
{
|
||||
physical_lengths = std::vector<std::size_t>{static_cast<std::size_t>(param.N_),
|
||||
static_cast<std::size_t>(param.G_),
|
||||
static_cast<std::size_t>(param.K_)};
|
||||
|
||||
physical_lengths.insert(physical_lengths.begin() + 1,
|
||||
param.output_spatial_lengths_.begin(),
|
||||
param.output_spatial_lengths_.begin() + param.num_dim_spatial_);
|
||||
}
|
||||
else
|
||||
{
|
||||
printf("%s\n", __func__);
|
||||
printf("%s\n", OutLayout::name);
|
||||
throw std::runtime_error("wrong! unsupported layout");
|
||||
}
|
||||
|
||||
return transpose_host_tensor_descriptor_given_new2old(
|
||||
HostTensorDescriptor(physical_lengths),
|
||||
detail::get_layout_transpose_gnchw_to_old<OutLayout>());
|
||||
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace ck_tile
|
||||
283
include/ck_tile/host/convolution_parameter.hpp
Normal file
283
include/ck_tile/host/convolution_parameter.hpp
Normal file
@@ -0,0 +1,283 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <numeric>
|
||||
#include <iterator>
|
||||
#include <vector>
|
||||
|
||||
namespace ck_tile {
|
||||
namespace conv {
|
||||
|
||||
struct ConvParam
|
||||
{
|
||||
ConvParam();
|
||||
ConvParam(ck_tile::index_t n_dim,
|
||||
ck_tile::index_t group_count,
|
||||
ck_tile::index_t n_batch,
|
||||
ck_tile::index_t n_out_channels,
|
||||
ck_tile::index_t n_in_channels,
|
||||
const std::vector<ck_tile::index_t>& filters_len,
|
||||
const std::vector<ck_tile::index_t>& input_len,
|
||||
const std::vector<ck_tile::index_t>& strides,
|
||||
const std::vector<ck_tile::index_t>& dilations,
|
||||
const std::vector<ck_tile::index_t>& left_pads,
|
||||
const std::vector<ck_tile::index_t>& right_pads)
|
||||
: num_dim_spatial_(static_cast<ck_tile::long_index_t>(n_dim)),
|
||||
G_(static_cast<ck_tile::long_index_t>(group_count)),
|
||||
N_(static_cast<ck_tile::long_index_t>(n_batch)),
|
||||
K_(static_cast<ck_tile::long_index_t>(n_out_channels)),
|
||||
C_(static_cast<ck_tile::long_index_t>(n_in_channels)),
|
||||
filter_spatial_lengths_(num_dim_spatial_),
|
||||
input_spatial_lengths_(num_dim_spatial_),
|
||||
output_spatial_lengths_(num_dim_spatial_),
|
||||
conv_filter_strides_(num_dim_spatial_),
|
||||
conv_filter_dilations_(num_dim_spatial_),
|
||||
input_left_pads_(num_dim_spatial_),
|
||||
input_right_pads_(num_dim_spatial_)
|
||||
{
|
||||
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
|
||||
{
|
||||
throw(std::runtime_error(
|
||||
"ConvParam::ConvParam: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
}
|
||||
|
||||
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
|
||||
{
|
||||
filter_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(filters_len[i]);
|
||||
input_spatial_lengths_[i] = static_cast<ck_tile::long_index_t>(input_len[i]);
|
||||
conv_filter_strides_[i] = static_cast<ck_tile::long_index_t>(strides[i]);
|
||||
conv_filter_dilations_[i] = static_cast<ck_tile::long_index_t>(dilations[i]);
|
||||
input_left_pads_[i] = static_cast<ck_tile::long_index_t>(left_pads[i]);
|
||||
input_right_pads_[i] = static_cast<ck_tile::long_index_t>(right_pads[i]);
|
||||
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck_tile::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_[i] =
|
||||
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
|
||||
conv_filter_strides_[i] +
|
||||
1;
|
||||
}
|
||||
}
|
||||
|
||||
ConvParam(ck_tile::long_index_t n_dim,
|
||||
ck_tile::long_index_t group_count,
|
||||
ck_tile::long_index_t n_batch,
|
||||
ck_tile::long_index_t n_out_channels,
|
||||
ck_tile::long_index_t n_in_channels,
|
||||
const std::vector<ck_tile::long_index_t>& filters_len,
|
||||
const std::vector<ck_tile::long_index_t>& input_len,
|
||||
const std::vector<ck_tile::long_index_t>& strides,
|
||||
const std::vector<ck_tile::long_index_t>& dilations,
|
||||
const std::vector<ck_tile::long_index_t>& left_pads,
|
||||
const std::vector<ck_tile::long_index_t>& right_pads)
|
||||
: num_dim_spatial_(n_dim),
|
||||
G_(group_count),
|
||||
N_(n_batch),
|
||||
K_(n_out_channels),
|
||||
C_(n_in_channels),
|
||||
filter_spatial_lengths_(filters_len),
|
||||
input_spatial_lengths_(input_len),
|
||||
output_spatial_lengths_(num_dim_spatial_),
|
||||
conv_filter_strides_(strides),
|
||||
conv_filter_dilations_(dilations),
|
||||
input_left_pads_(left_pads),
|
||||
input_right_pads_(right_pads)
|
||||
{
|
||||
if(static_cast<ck_tile::index_t>(filter_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_spatial_lengths_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(conv_filter_strides_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(conv_filter_dilations_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_left_pads_.size()) != num_dim_spatial_ ||
|
||||
static_cast<ck_tile::index_t>(input_right_pads_.size()) != num_dim_spatial_)
|
||||
{
|
||||
throw(std::runtime_error(
|
||||
"ConvParam::ConvParam: "
|
||||
"parameter size is different from number of declared dimensions!"));
|
||||
}
|
||||
|
||||
for(ck_tile::index_t i = 0; i < num_dim_spatial_; ++i)
|
||||
{
|
||||
// XEff = (X - 1) * conv_dilation_w + 1;
|
||||
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
|
||||
const ck_tile::long_index_t x_eff =
|
||||
(filter_spatial_lengths_[i] - 1) * conv_filter_dilations_[i] + 1;
|
||||
|
||||
output_spatial_lengths_[i] =
|
||||
(input_spatial_lengths_[i] + input_left_pads_[i] + input_right_pads_[i] - x_eff) /
|
||||
conv_filter_strides_[i] +
|
||||
1;
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::long_index_t num_dim_spatial_;
|
||||
ck_tile::long_index_t G_;
|
||||
ck_tile::long_index_t N_;
|
||||
ck_tile::long_index_t K_;
|
||||
ck_tile::long_index_t C_;
|
||||
|
||||
std::vector<ck_tile::long_index_t> filter_spatial_lengths_;
|
||||
std::vector<ck_tile::long_index_t> input_spatial_lengths_;
|
||||
std::vector<ck_tile::long_index_t> output_spatial_lengths_;
|
||||
|
||||
std::vector<ck_tile::long_index_t> conv_filter_strides_;
|
||||
std::vector<ck_tile::long_index_t> conv_filter_dilations_;
|
||||
|
||||
std::vector<ck_tile::long_index_t> input_left_pads_;
|
||||
std::vector<ck_tile::long_index_t> input_right_pads_;
|
||||
|
||||
std::vector<ck_tile::long_index_t> GetOutputSpatialLengths() const
|
||||
{
|
||||
return output_spatial_lengths_;
|
||||
}
|
||||
|
||||
std::size_t GetFlops() const
|
||||
{
|
||||
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
|
||||
return static_cast<std::size_t>(2) * G_ * N_ * K_ * C_ *
|
||||
std::accumulate(std::begin(output_spatial_lengths_),
|
||||
std::next(std::begin(output_spatial_lengths_), num_dim_spatial_),
|
||||
1,
|
||||
std::multiplies<>()) *
|
||||
std::accumulate(std::begin(filter_spatial_lengths_),
|
||||
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
|
||||
1,
|
||||
std::multiplies<>());
|
||||
}
|
||||
|
||||
template <typename InDataType>
|
||||
std::size_t GetInputByte() const
|
||||
{
|
||||
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>) +
|
||||
return sizeof(InDataType) *
|
||||
(G_ * N_ * C_ *
|
||||
std::accumulate(std::begin(input_spatial_lengths_),
|
||||
std::next(std::begin(input_spatial_lengths_), num_dim_spatial_),
|
||||
1,
|
||||
std::multiplies<>()));
|
||||
}
|
||||
|
||||
template <typename WeiDataType>
|
||||
std::size_t GetWeightByte() const
|
||||
{
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>) +
|
||||
return sizeof(WeiDataType) *
|
||||
(G_ * K_ * C_ *
|
||||
std::accumulate(std::begin(filter_spatial_lengths_),
|
||||
std::next(std::begin(filter_spatial_lengths_), num_dim_spatial_),
|
||||
1,
|
||||
std::multiplies<>()));
|
||||
}
|
||||
|
||||
template <typename OutDataType>
|
||||
std::size_t GetOutputByte() const
|
||||
{
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(OutDataType) * (G_ * N_ * K_ *
|
||||
std::accumulate(std::begin(output_spatial_lengths_),
|
||||
std::end(output_spatial_lengths_),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
std::size_t GetByte() const
|
||||
{
|
||||
return GetInputByte<InDataType>() + GetWeightByte<WeiDataType>() +
|
||||
GetOutputByte<OutDataType>();
|
||||
}
|
||||
};
|
||||
|
||||
ConvParam::ConvParam()
|
||||
: ConvParam::ConvParam(2, 1, 128, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, {1, 1})
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST std::string get_conv_param_parser_helper_msg()
|
||||
{
|
||||
std::string msg;
|
||||
|
||||
msg += "Following arguments (depending on number of spatial dims):\n"
|
||||
" Number of spatial dimensions (1=Conv1d, 2=Conv2d, 3=Conv3d)\n"
|
||||
" G, N, K, C, \n"
|
||||
" <filter spatial dimensions>, (ie Y, X for 2D)\n"
|
||||
" <input image spatial dimensions>, (ie Hi, Wi for 2D)\n"
|
||||
" <strides>, (ie Sy, Sx for 2D)\n"
|
||||
" <dilations>, (ie Dy, Dx for 2D)\n"
|
||||
" <left padding>, (ie LeftPy, LeftPx for 2D)\n"
|
||||
" <right padding>, (ie RightPy, RightPx for 2D)\n";
|
||||
|
||||
return msg;
|
||||
}
|
||||
|
||||
CK_TILE_HOST ck_tile::conv::ConvParam
|
||||
parse_conv_param(int num_dim_spatial, int arg_idx, char* const argv[])
|
||||
{
|
||||
const ck_tile::long_index_t G = std::stol(argv[arg_idx++]);
|
||||
const ck_tile::long_index_t N = std::stol(argv[arg_idx++]);
|
||||
const ck_tile::long_index_t K = std::stol(argv[arg_idx++]);
|
||||
const ck_tile::long_index_t C = std::stol(argv[arg_idx++]);
|
||||
|
||||
std::vector<ck_tile::long_index_t> filter_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck_tile::long_index_t> input_spatial_lengths(num_dim_spatial);
|
||||
std::vector<ck_tile::long_index_t> conv_filter_strides(num_dim_spatial);
|
||||
std::vector<ck_tile::long_index_t> conv_filter_dilations(num_dim_spatial);
|
||||
std::vector<ck_tile::long_index_t> input_left_pads(num_dim_spatial);
|
||||
std::vector<ck_tile::long_index_t> input_right_pads(num_dim_spatial);
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
filter_spatial_lengths[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
input_spatial_lengths[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
conv_filter_strides[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
conv_filter_dilations[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
input_left_pads[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
for(int i = 0; i < num_dim_spatial; ++i)
|
||||
{
|
||||
input_right_pads[i] = std::stol(argv[arg_idx++]);
|
||||
}
|
||||
|
||||
return ck_tile::conv::ConvParam{num_dim_spatial,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
filter_spatial_lengths,
|
||||
input_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
}
|
||||
|
||||
} // namespace conv
|
||||
} // namespace ck_tile
|
||||
@@ -176,7 +176,20 @@ struct HostTensorDescriptor
|
||||
return std::inner_product(iss.begin(), iss.end(), mStrides.begin(), std::size_t{0});
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc);
|
||||
friend std::ostream& operator<<(std::ostream& os, const HostTensorDescriptor& desc)
|
||||
{
|
||||
os << "dim " << desc.get_num_of_dimension() << ", ";
|
||||
|
||||
os << "lengths {";
|
||||
LogRange(os, desc.get_lengths(), ", ");
|
||||
os << "}, ";
|
||||
|
||||
os << "strides {";
|
||||
LogRange(os, desc.get_strides(), ", ");
|
||||
os << "}";
|
||||
|
||||
return os;
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<std::size_t> mLens;
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -9,53 +9,125 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST void reference_im2col(HostTensor<T>& in_mtx_host_ref,
|
||||
const HostTensor<T>& in_host,
|
||||
int /*N*/,
|
||||
int /*K*/,
|
||||
int C,
|
||||
int /*Y*/,
|
||||
int X,
|
||||
int Hi,
|
||||
int Wi,
|
||||
int Ho,
|
||||
int Wo,
|
||||
int ConvStrideH,
|
||||
int ConvStrideW,
|
||||
int ConvDilationH,
|
||||
int ConvDilationW,
|
||||
int InLeftPadH,
|
||||
int InLeftPadW,
|
||||
int /*InRightPadH*/,
|
||||
int /*InRightPadW*/)
|
||||
template <typename InDataType, typename OutDataType, index_t NDimSpatial>
|
||||
CK_TILE_HOST void reference_im2col(const HostTensor<InDataType>& in_host,
|
||||
HostTensor<OutDataType>& out_host,
|
||||
const ck_tile::conv::ConvParam& conv_params)
|
||||
{
|
||||
int GemmM = in_mtx_host_ref.get_lengths()[0];
|
||||
int GemmK = in_mtx_host_ref.get_lengths()[1];
|
||||
const long_index_t G = in_host.get_lengths()[0];
|
||||
const long_index_t N = in_host.get_lengths()[1];
|
||||
const long_index_t C = in_host.get_lengths()[2];
|
||||
|
||||
for(int gemm_m = 0; gemm_m < GemmM; ++gemm_m)
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
int mtmp = gemm_m;
|
||||
int n = mtmp / (Ho * Wo);
|
||||
mtmp -= n * Ho * Wo;
|
||||
int ho = mtmp / Wo;
|
||||
int wo = mtmp - ho * Wo;
|
||||
const long_index_t Wo = conv_params.output_spatial_lengths_[0];
|
||||
auto func = [&](auto g, auto n, auto wo) {
|
||||
long_index_t row = n * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(int gemm_k = 0; gemm_k < GemmK; ++gemm_k)
|
||||
{
|
||||
int ktmp = gemm_k;
|
||||
int y = ktmp / (X * C);
|
||||
ktmp -= y * X * C;
|
||||
int x = ktmp / C;
|
||||
int c = ktmp - x * C;
|
||||
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[0]; ++x)
|
||||
{
|
||||
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[0]) +
|
||||
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[0]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
|
||||
|
||||
int hi = y * ConvDilationH + ho * ConvStrideH - InLeftPadH;
|
||||
int wi = x * ConvDilationW + wo * ConvStrideW - InLeftPadW;
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
if(wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[3])
|
||||
{
|
||||
InDataType v_in = in_host(g, n, c, wi);
|
||||
out_host(g, row, column) = type_convert<OutDataType>(v_in);
|
||||
}
|
||||
column++;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
bool inbound = (hi >= 0 && hi < Hi && wi >= 0 && wi < Wi);
|
||||
make_ParallelTensorFunctor(func, G, N, Wo)(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
const long_index_t Ho = conv_params.output_spatial_lengths_[0];
|
||||
const long_index_t Wo = conv_params.output_spatial_lengths_[1];
|
||||
|
||||
in_mtx_host_ref(gemm_m, gemm_k) = inbound ? in_host(n, hi, wi, c) : 0;
|
||||
}
|
||||
auto func = [&](auto g, auto n, auto ho, auto wo) {
|
||||
long_index_t row = n * Ho * Wo + ho * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[0]; ++y)
|
||||
{
|
||||
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[0]) +
|
||||
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[0]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
|
||||
|
||||
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[1]; ++x)
|
||||
{
|
||||
auto wi = static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[1]) +
|
||||
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[1]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
|
||||
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
|
||||
if(hi >= 0 && type_convert<std::size_t>(hi) < in_host.get_lengths()[3] &&
|
||||
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[4])
|
||||
{
|
||||
InDataType v_in = in_host(g, n, c, hi, wi);
|
||||
out_host(g, row, column) = type_convert<OutDataType>(v_in);
|
||||
}
|
||||
column++;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, G, N, Ho, Wo)(std::thread::hardware_concurrency());
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
const long_index_t Do = conv_params.output_spatial_lengths_[0];
|
||||
const long_index_t Ho = conv_params.output_spatial_lengths_[1];
|
||||
const long_index_t Wo = conv_params.output_spatial_lengths_[2];
|
||||
|
||||
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
|
||||
long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
|
||||
long_index_t column = 0;
|
||||
|
||||
for(long_index_t z = 0; z < conv_params.filter_spatial_lengths_[0]; ++z)
|
||||
{
|
||||
auto di = static_cast<long_index_t>(d_o * conv_params.conv_filter_strides_[0]) +
|
||||
static_cast<long_index_t>(z * conv_params.conv_filter_dilations_[0]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[0]);
|
||||
for(long_index_t y = 0; y < conv_params.filter_spatial_lengths_[1]; ++y)
|
||||
{
|
||||
auto hi = static_cast<long_index_t>(ho * conv_params.conv_filter_strides_[1]) +
|
||||
static_cast<long_index_t>(y * conv_params.conv_filter_dilations_[1]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[1]);
|
||||
for(long_index_t x = 0; x < conv_params.filter_spatial_lengths_[2]; ++x)
|
||||
{
|
||||
auto wi =
|
||||
static_cast<long_index_t>(wo * conv_params.conv_filter_strides_[2]) +
|
||||
static_cast<long_index_t>(x * conv_params.conv_filter_dilations_[2]) -
|
||||
static_cast<long_index_t>(conv_params.input_left_pads_[2]);
|
||||
for(long_index_t c = 0; c < C; ++c)
|
||||
{
|
||||
if(di >= 0 &&
|
||||
type_convert<std::size_t>(di) < in_host.get_lengths()[3] &&
|
||||
hi >= 0 &&
|
||||
type_convert<std::size_t>(hi) < in_host.get_lengths()[4] &&
|
||||
wi >= 0 && type_convert<std::size_t>(wi) < in_host.get_lengths()[5])
|
||||
{
|
||||
InDataType v_in = in_host(g, n, c, di, hi, wi);
|
||||
out_host(g, row, column) = type_convert<OutDataType>(v_in);
|
||||
}
|
||||
column++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(func, G, N, Do, Ho, Wo)(std::thread::hardware_concurrency());
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
|
||||
{
|
||||
auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);
|
||||
|
||||
const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
|
||||
const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
|
||||
const index_t split_start = x_per_split * i_split;
|
||||
const index_t split_end = (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
|
||||
const index_t split_end = split_start + x_per_split;
|
||||
|
||||
return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
|
||||
ck_tile::min(origin_end, split_end));
|
||||
|
||||
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
void* o_ptr;
|
||||
|
||||
ck_tile::index_t batch;
|
||||
ck_tile::index_t max_seqlen_q;
|
||||
|
||||
ck_tile::index_t seqlen_q;
|
||||
ck_tile::index_t hdim_v;
|
||||
ck_tile::index_t num_splits;
|
||||
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
ck_tile::index_t nhead_stride_o;
|
||||
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
};
|
||||
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
|
||||
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
|
||||
{
|
||||
ck_tile::index_t batch_stride_o;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
ck_tile::index_t batch_stride_o;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_ptr,
|
||||
o_ptr,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
seqlen_q,
|
||||
hdim_v,
|
||||
num_splits,
|
||||
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
{}, // placeholder for fp8_static_quant args
|
||||
batch_stride_o,
|
||||
batch_stride_lse_acc};
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
batch_stride_o};
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
void* lse_ptr,
|
||||
void* o_ptr,
|
||||
ck_tile::index_t batch,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
const void* seqstart_q_ptr,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits,
|
||||
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t nhead_stride_lse,
|
||||
ck_tile::index_t nhead_stride_o,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc)
|
||||
{
|
||||
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_ptr,
|
||||
o_ptr,
|
||||
batch,
|
||||
max_seqlen_q,
|
||||
-1, // seqlen will be updated by another pointer
|
||||
hdim_v,
|
||||
num_splits,
|
||||
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
nhead_stride_o,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for lse
|
||||
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
return kargs;
|
||||
}
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
|
||||
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
|
||||
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
|
||||
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
long_index_t batch_offset_lse = 0;
|
||||
long_index_t batch_offset_o = 0;
|
||||
|
||||
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
// get starting offset for each batch
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_o_acc = query_start * kargs.row_stride_o_acc;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = query_start;
|
||||
}
|
||||
|
||||
batch_offset_o = query_start * kargs.row_stride_o;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
|
||||
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
|
||||
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
}
|
||||
else
|
||||
{
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
if constexpr(kStoreLSE)
|
||||
{
|
||||
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
|
||||
}
|
||||
|
||||
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
auto o_acc_dram = [&]() {
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
|
||||
number<FmhaPipeline::kAlignmentOacc>{},
|
||||
number<1>{});
|
||||
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
|
||||
sequence<false, kPadSeqLenQ, kPadHeadDimV>{});
|
||||
|
||||
const index_t padded_max_seqlen_q =
|
||||
const index_t padded_seqlen_q =
|
||||
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
|
||||
const index_t padded_hdim_v =
|
||||
o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];
|
||||
|
||||
return transform_tensor_view(
|
||||
o_acc_dram_view,
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
|
||||
make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
|
||||
make_pass_through_transform(padded_hdim_v)),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
identity{}, // lse_element_func
|
||||
composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
|
||||
kargs.num_splits,
|
||||
kargs.max_seqlen_q,
|
||||
kargs.seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
else
|
||||
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
|
||||
o_acc_dram_window,
|
||||
lse_dram_window,
|
||||
kargs.num_splits,
|
||||
kargs.max_seqlen_q,
|
||||
kargs.seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
|
||||
static constexpr ck_tile::index_t kM0 = kM0_;
|
||||
static constexpr ck_tile::index_t kN1 = kN1_;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
|
||||
ck_tile::index_t nhead_,
|
||||
ck_tile::index_t seqlen_q_,
|
||||
ck_tile::index_t hdim_v_)
|
||||
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v_, kN1),
|
||||
nhead_,
|
||||
batch_size_);
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1),
|
||||
nhead,
|
||||
batch_size);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
|
||||
{
|
||||
// const index_t num_tile_m0 = seqlen_q / kM0;
|
||||
const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
|
||||
|
||||
const index_t i_block = blockIdx.x;
|
||||
|
||||
@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_lse_acc;
|
||||
ck_tile::index_t nhead_stride_o_acc;
|
||||
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
|
||||
ck_tile::index_t split_stride_lse_acc;
|
||||
ck_tile::index_t split_stride_o_acc;
|
||||
};
|
||||
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t batch_stride_q;
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_lse_acc;
|
||||
ck_tile::index_t batch_stride_o_acc;
|
||||
};
|
||||
|
||||
struct GroupModeKargs
|
||||
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
|
||||
const int32_t* seqstart_k_ptr;
|
||||
const int32_t* seqlen_k_ptr;
|
||||
|
||||
ck_tile::index_t batch_stride_k;
|
||||
ck_tile::index_t batch_stride_v;
|
||||
ck_tile::index_t batch_stride_k; // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_v; // only used for paged-kvcache
|
||||
};
|
||||
|
||||
using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
|
||||
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
|
||||
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
|
||||
batch_stride_q,
|
||||
batch_stride_k,
|
||||
batch_stride_v};
|
||||
batch_stride_v,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc};
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
|
||||
ck_tile::index_t nhead_stride_bias,
|
||||
ck_tile::index_t nhead_stride_lse_acc,
|
||||
ck_tile::index_t nhead_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_k,
|
||||
ck_tile::index_t batch_stride_v,
|
||||
ck_tile::index_t batch_stride_lse_acc,
|
||||
ck_tile::index_t batch_stride_o_acc,
|
||||
ck_tile::index_t batch_stride_k, // only used for paged-kvcache
|
||||
ck_tile::index_t batch_stride_v, // only used for paged-kvcache
|
||||
ck_tile::index_t split_stride_lse_acc,
|
||||
ck_tile::index_t split_stride_o_acc,
|
||||
ck_tile::index_t window_size_left,
|
||||
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
|
||||
nhead_stride_v,
|
||||
nhead_stride_lse_acc,
|
||||
nhead_stride_o_acc,
|
||||
batch_stride_lse_acc,
|
||||
batch_stride_o_acc,
|
||||
split_stride_lse_acc,
|
||||
split_stride_o_acc}, // args for common karg
|
||||
{}, // placeholder for bias
|
||||
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits);
|
||||
return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
|
||||
}
|
||||
|
||||
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
|
||||
long_index_t batch_offset_v = 0;
|
||||
long_index_t batch_offset_bias = 0;
|
||||
long_index_t batch_offset_lse_acc = 0;
|
||||
const long_index_t batch_offset_o_acc =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
long_index_t batch_offset_o_acc = 0;
|
||||
|
||||
if constexpr(kIsGroupMode)
|
||||
{
|
||||
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
|
||||
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
|
||||
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
|
||||
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_q = query_start * kargs.stride_q;
|
||||
batch_offset_k = key_start * kargs.stride_k;
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
batch_offset_v = key_start * kargs.stride_v;
|
||||
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
|
||||
batch_offset_bias = query_start * kargs.stride_bias + key_start;
|
||||
}
|
||||
|
||||
batch_offset_lse_acc = query_start;
|
||||
batch_offset_o_acc = query_start * kargs.stride_o_acc;
|
||||
|
||||
// get real # queries & # keys under group mode
|
||||
kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
|
||||
|
||||
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
|
||||
batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
|
||||
batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
|
||||
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
|
||||
batch_offset_o_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
|
||||
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
|
||||
{
|
||||
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
|
||||
const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
o_acc_ptr,
|
||||
make_tuple(kargs.seqlen_q, kargs.hdim_v),
|
||||
make_tuple(kargs.hdim_v, 1),
|
||||
number<FmhaPipeline::kAlignmentO>{},
|
||||
make_tuple(kargs.stride_o_acc, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
|
||||
@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
|
||||
|
||||
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
|
||||
ck_tile::index_t nhead,
|
||||
ck_tile::index_t seqlen_q,
|
||||
ck_tile::index_t max_seqlen_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_splits)
|
||||
{
|
||||
// TODO: this may need tuning
|
||||
return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) *
|
||||
return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
|
||||
ck_tile::integer_divide_ceil(hdim_v, kN1),
|
||||
nhead * num_splits,
|
||||
batch_size);
|
||||
|
||||
@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
},
|
||||
s_acc,
|
||||
bias_s_tile);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
|
||||
{
|
||||
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 4, OGrad@V Gemm2
|
||||
auto dp_acc = SPGradBlockTileType{};
|
||||
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<2>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// STAGE 5, P^T(PGrad^T - D)
|
||||
auto ds = SPGradBlockTileType{};
|
||||
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
Policy::template MakeBiasTileDistribution<Problem>());
|
||||
shuffle_tile(dbias_tile, shuffled_dbias_tile);
|
||||
store_tile(dbias_dram_window, dbias_tile);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// STAGE 6, SGrad^T@Q^T Gemm3
|
||||
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<3>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// STAGE 7, SGrad@K^T Gemm4
|
||||
auto dq_acc = QGradBlockTileType{};
|
||||
clear_tile(dq_acc);
|
||||
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
});
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<4>();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Results Scale
|
||||
if constexpr(FmhaDropout::IsDropout)
|
||||
|
||||
@@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<0>()
|
||||
{
|
||||
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
|
||||
// Comp: Q x K
|
||||
@@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<1>()
|
||||
{
|
||||
// Mem: Q^T LDS load
|
||||
// Comp: OGrad x V
|
||||
@@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<2>()
|
||||
{
|
||||
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
|
||||
// Comp: PT x OGrad
|
||||
@@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<3>()
|
||||
{
|
||||
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
|
||||
// Comp: SGradT x QT
|
||||
@@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
template <>
|
||||
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
|
||||
CK_TILE_DEVICE constexpr void GemmStagedScheduler<4>()
|
||||
{
|
||||
// Mem: SGrad, OGrad, D LDS load.
|
||||
// Comp: SGrad x KT
|
||||
|
||||
@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const LSEElementFunction& lse_element_func,
|
||||
const OaccElementFunction& o_acc_element_func,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q,
|
||||
index_t seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
// lse_acc tile in LDS
|
||||
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
|
||||
clear_tile(o_acc);
|
||||
|
||||
const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
|
||||
const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;
|
||||
|
||||
for(index_t i_split = 0; i_split < num_splits; ++i_split)
|
||||
{
|
||||
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
});
|
||||
}
|
||||
|
||||
move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
|
||||
move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
|
||||
}
|
||||
|
||||
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
|
||||
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
const OaccDramBlockWindow& o_acc_dram_block_window,
|
||||
LSEDramBlockWindow& lse_dram_block_window,
|
||||
index_t num_splits,
|
||||
index_t max_seqlen_q,
|
||||
index_t seqlen_q,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
return operator()(lse_acc_dram_block_window,
|
||||
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
|
||||
identity{},
|
||||
identity{},
|
||||
num_splits,
|
||||
max_seqlen_q,
|
||||
seqlen_q,
|
||||
smem_ptr);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
|
||||
}();
|
||||
|
||||
static constexpr index_t kAlignmentO =
|
||||
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
|
||||
static constexpr index_t kAlignmentBias =
|
||||
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
|
||||
|
||||
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
|
||||
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
|
||||
|
||||
// check early exit if masked and no work to do.
|
||||
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
|
||||
// check early exit if no work to do
|
||||
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
|
||||
{
|
||||
const index_t original_num_total_loop =
|
||||
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
|
||||
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
|
||||
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
|
||||
constexpr auto i_idx = make_tuple(idx0);
|
||||
const auto tmp = [&]() {
|
||||
if constexpr(FmhaMask::IsMasking)
|
||||
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
|
||||
FmhaMask::IsMasking)
|
||||
{
|
||||
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
|
||||
}
|
||||
|
||||
9
include/ck_tile/ops/image_to_column.hpp
Normal file
9
include/ck_tile/ops/image_to_column.hpp
Normal file
@@ -0,0 +1,9 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
|
||||
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
|
||||
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
@@ -0,0 +1,224 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_>
|
||||
struct ImageToColumn
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
static constexpr auto I3 = number<3>{};
|
||||
static constexpr auto I4 = number<4>{};
|
||||
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
using InDataType = remove_cvref_t<typename Problem::InDataType>;
|
||||
using OutDataType = remove_cvref_t<typename Problem::OutDataType>;
|
||||
|
||||
static constexpr index_t NDimSpatial = Problem::NDimSpatial;
|
||||
|
||||
static constexpr index_t AligmentIn = Problem::AligmentIn;
|
||||
static constexpr index_t AligmentOut = Problem::AligmentOut;
|
||||
|
||||
static_assert(NDimSpatial == 2, "Not supported.");
|
||||
|
||||
static constexpr index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
|
||||
static constexpr index_t kKPerBlock = Problem::BlockShape::kKPerBlock;
|
||||
|
||||
struct Kargs
|
||||
{
|
||||
const void* p_in;
|
||||
void* p_out;
|
||||
|
||||
const long_index_t G;
|
||||
const long_index_t N;
|
||||
const long_index_t C;
|
||||
|
||||
const array<long_index_t, NDimSpatial> input_spatial_lengths;
|
||||
const array<long_index_t, NDimSpatial> filter_spatial_lengths;
|
||||
const array<long_index_t, NDimSpatial> output_spatial_lengths;
|
||||
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides;
|
||||
const array<long_index_t, 3> gemm_g_m_k_strides;
|
||||
const array<long_index_t, NDimSpatial> conv_filter_strides;
|
||||
const array<long_index_t, NDimSpatial> conv_filter_dilations;
|
||||
const array<long_index_t, NDimSpatial> input_left_pads;
|
||||
const array<long_index_t, NDimSpatial> input_right_pads;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static constexpr Kargs
|
||||
MakeKargs(const void* p_in,
|
||||
void* p_out,
|
||||
const long_index_t G,
|
||||
const long_index_t N,
|
||||
const long_index_t C,
|
||||
const array<long_index_t, NDimSpatial> input_spatial_lengths,
|
||||
const array<long_index_t, NDimSpatial> filter_spatial_lengths,
|
||||
const array<long_index_t, NDimSpatial> output_spatial_lengths,
|
||||
const array<long_index_t, NDimSpatial + 3> image_g_n_c_wis_strides,
|
||||
const array<long_index_t, 3> gemm_g_m_k_strides,
|
||||
const array<long_index_t, NDimSpatial> conv_filter_strides,
|
||||
const array<long_index_t, NDimSpatial> conv_filter_dilations,
|
||||
const array<long_index_t, NDimSpatial> input_left_pads,
|
||||
const array<long_index_t, NDimSpatial> input_right_pads)
|
||||
{
|
||||
return Kargs{p_in,
|
||||
p_out,
|
||||
G,
|
||||
N,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
image_g_n_c_wis_strides,
|
||||
gemm_g_m_k_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads};
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
|
||||
{
|
||||
return dim3(
|
||||
integer_divide_ceil(GemmM, kMPerBlock), integer_divide_ceil(GemmK, kKPerBlock), Batch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const
|
||||
{
|
||||
static_assert(NDimSpatial == 2, "Not supported.");
|
||||
|
||||
const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
kargs.N, kargs.input_spatial_lengths[I0], kargs.input_spatial_lengths[I1], kargs.C),
|
||||
make_tuple(kargs.image_g_n_c_wis_strides[I1],
|
||||
kargs.image_g_n_c_wis_strides[I3],
|
||||
kargs.image_g_n_c_wis_strides[I4],
|
||||
kargs.image_g_n_c_wis_strides[I2]),
|
||||
number<AligmentIn>{},
|
||||
I1);
|
||||
|
||||
const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_desc,
|
||||
make_tuple(make_pass_through_transform(kargs.N),
|
||||
make_pad_transform(kargs.input_spatial_lengths[I0],
|
||||
kargs.input_left_pads[I0],
|
||||
kargs.input_right_pads[I0]),
|
||||
make_pad_transform(kargs.input_spatial_lengths[I1],
|
||||
kargs.input_left_pads[I1],
|
||||
kargs.input_right_pads[I1]),
|
||||
make_pass_through_transform(kargs.C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(kargs.N),
|
||||
make_embed_transform(
|
||||
make_tuple(kargs.filter_spatial_lengths[I0], kargs.output_spatial_lengths[I0]),
|
||||
make_tuple(kargs.conv_filter_dilations[I0], kargs.conv_filter_strides[I0])),
|
||||
make_embed_transform(
|
||||
make_tuple(kargs.filter_spatial_lengths[I1], kargs.output_spatial_lengths[I1]),
|
||||
make_tuple(kargs.conv_filter_dilations[I1], kargs.conv_filter_strides[I1])),
|
||||
make_pass_through_transform(kargs.C)),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
in_n_y_ho_x_wo_c_desc,
|
||||
make_tuple(
|
||||
make_merge_transform(make_tuple(
|
||||
kargs.N, kargs.output_spatial_lengths[I0], kargs.output_spatial_lengths[I1])),
|
||||
make_merge_transform(make_tuple(
|
||||
kargs.filter_spatial_lengths[I0], kargs.filter_spatial_lengths[I1], kargs.C))),
|
||||
make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const
|
||||
{
|
||||
static_assert(NDimSpatial == 2, "Not supported.");
|
||||
const index_t M = kargs.N * static_cast<index_t>(kargs.output_spatial_lengths[I0] *
|
||||
kargs.output_spatial_lengths[I1]);
|
||||
const index_t K = kargs.C * static_cast<index_t>(kargs.filter_spatial_lengths[I0] *
|
||||
kargs.filter_spatial_lengths[I1]);
|
||||
return make_tuple(M, K);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBlockTileDistribution()
|
||||
{
|
||||
using P = typename Problem::BlockShape;
|
||||
// P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
|
||||
// Y: {kMPerThread, kKPerThread}
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<P::kMWarpPerBlock, P::kMThreadPerWarp, P::kMPerThread>,
|
||||
sequence<P::kKWarpPerBlock, P::kKThreadPerWarp, P::kKPerThread>>,
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 0>, sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<2, 2>>{});
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const
|
||||
{
|
||||
const auto [M, K] = CalculateMKDims(kargs);
|
||||
|
||||
const index_t iM = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
|
||||
const index_t iK = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock);
|
||||
const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z);
|
||||
|
||||
const auto in_offset = iBatch * kargs.image_g_n_c_wis_strides[I0];
|
||||
const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0];
|
||||
|
||||
const auto image_m_k = make_tensor_view<address_space_enum::global>(
|
||||
static_cast<const InDataType*>(kargs.p_in) + in_offset, MakeImageMKDesc(kargs));
|
||||
const auto gemm_m_k = make_naive_tensor_view<address_space_enum::global>(
|
||||
static_cast<OutDataType*>(kargs.p_out) + out_offset,
|
||||
make_tuple(M, K),
|
||||
make_tuple(kargs.gemm_g_m_k_strides[I1], kargs.gemm_g_m_k_strides[I2]),
|
||||
number<AligmentOut>{},
|
||||
I1);
|
||||
|
||||
const auto image_m_k_padded =
|
||||
pad_tensor_view(image_m_k,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
sequence<false, true>{});
|
||||
const auto gemm_m_k_padded =
|
||||
pad_tensor_view(gemm_m_k,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
sequence<false, true>{});
|
||||
|
||||
constexpr auto dstr = MakeBlockTileDistribution();
|
||||
|
||||
const auto image_tile =
|
||||
make_tile_window(image_m_k_padded,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{iM, iK},
|
||||
dstr);
|
||||
|
||||
auto gemm_tile = make_tile_window(gemm_m_k_padded,
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
{iM, iK},
|
||||
dstr);
|
||||
|
||||
// load from Global
|
||||
const auto loaded_tile = load_tile(image_tile);
|
||||
// save to Global
|
||||
store_tile(gemm_tile, loaded_tile);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); }
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,27 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InDataType_,
|
||||
typename OutDataType_,
|
||||
typename BlockShape_,
|
||||
index_t NDimSpatial_,
|
||||
index_t AligmentIn_,
|
||||
index_t AligmentOut_>
|
||||
struct BlockImageToColumnProblem
|
||||
{
|
||||
using InDataType = remove_cvref_t<InDataType_>;
|
||||
using OutDataType = remove_cvref_t<OutDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr index_t NDimSpatial = NDimSpatial_;
|
||||
static constexpr index_t AligmentIn = AligmentIn_;
|
||||
static constexpr index_t AligmentOut = AligmentOut_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
template <typename ThreadTile, // Sequence<...
|
||||
typename WarpTile, // Sequence<...
|
||||
typename BlockTile> // Sequence<...
|
||||
struct TileImageToColumnShape
|
||||
{
|
||||
static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
|
||||
static constexpr index_t kKPerThread = ThreadTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
|
||||
static constexpr index_t kKPerWarp = WarpTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
|
||||
static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kKPerBlock = BlockTile::at(number<1>{});
|
||||
|
||||
static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
|
||||
static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;
|
||||
|
||||
static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user