mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK TILE] Implement cschuflle algorithm (#1842)
* [CK TILE] Implement cschuflle algorithm * Rebase * Vector store size fixes * fixes * Fixes * fixes * fmha fix * fixes * fixes of fixes
This commit is contained in:
@@ -1,194 +1,189 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
#define CK_TILE_MAX_RANK 5
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
|
||||
// memory.
|
||||
template <typename AccDataType_,
|
||||
typename ODataType_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kTilePermute_,
|
||||
index_t kRank_,
|
||||
index_t kPerm0,
|
||||
index_t kPerm1,
|
||||
index_t TileSize0,
|
||||
index_t TileSize1,
|
||||
index_t kPerm2 = 0,
|
||||
index_t kPerm3 = 0,
|
||||
index_t kPerm4 = 0,
|
||||
index_t TileSize2 = 0,
|
||||
index_t TileSize3 = 0,
|
||||
index_t TileSize4 = 0>
|
||||
typename CLayout_,
|
||||
index_t kBlockSize_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
index_t kMWave_,
|
||||
index_t kNWave_,
|
||||
index_t kMPerXdl_,
|
||||
index_t kNPerXdl_,
|
||||
index_t kKPerXdl_,
|
||||
bool isCTransposed_>
|
||||
struct CShuffleEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kTilePermute = kTilePermute_;
|
||||
static constexpr index_t kRank = kRank_;
|
||||
static constexpr index_t kPerm[CK_TILE_MAX_RANK] = {kPerm0, kPerm1, kPerm2, kPerm3, kPerm4};
|
||||
static constexpr index_t tile_sizes[CK_TILE_MAX_RANK] = {
|
||||
TileSize0, TileSize1, TileSize2, TileSize3, TileSize4};
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
static constexpr index_t kBlockSize = kBlockSize_;
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t kMWave = kMWave_;
|
||||
static constexpr index_t kNWave = kNWave_;
|
||||
static constexpr index_t kMPerXdl = kMPerXdl_;
|
||||
static constexpr index_t kNPerXdl = kNPerXdl_;
|
||||
static constexpr index_t kKPerXdl = kKPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct CShuffleEpilogue
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
const index_t* kPerm = Problem::kPerm;
|
||||
static constexpr bool kTilePermute = Problem::kTilePermute;
|
||||
static constexpr index_t kRank = Problem::kRank;
|
||||
const index_t* tile_sizes = Problem::tile_sizes;
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
|
||||
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
|
||||
static constexpr index_t kMWave = Problem::kMWave;
|
||||
static constexpr index_t kNWave = Problem::kNWave;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
static constexpr index_t kMPerIteration = kMPerXdl * kMWave;
|
||||
static constexpr index_t kNPerIteration = kNPerXdl * kNWave;
|
||||
|
||||
// No additional shared memory needed
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
using WG = WarpGemmMfmaDispatcher<ODataType,
|
||||
ODataType,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
kKPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed()
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
using CWarpTensor = typename WG::CWarpTensor;
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
*
|
||||
* @note The vector store size for output C tensor would depend on multiple factors
|
||||
* like its data layout and warp gemm C transposition. In general it would
|
||||
* be the number of consecutive elements in contiguous C dimension hold by
|
||||
* single thread.
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
// TODO: At now CShuffle doesn't allow to vector store after permute.
|
||||
// It should be fixed and this function should return true.
|
||||
return false;
|
||||
constexpr index_t MaxVectorStoreSize = 16;
|
||||
return MaxVectorStoreSize / sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <typename OAccTile>
|
||||
CK_TILE_DEVICE void permute_tile_data(OAccTile& o_acc_tile)
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeLdsBlockDescriptor()
|
||||
{
|
||||
using DataType = typename OAccTile::DataType;
|
||||
|
||||
// Get thread buffer
|
||||
auto& thread_buf = o_acc_tile.get_thread_buffer();
|
||||
|
||||
// Create a temporary buffer to hold the permuted data
|
||||
thread_buffer<DataType, OAccTile::kThreadElementSpaceSize> permuted_thread_buf;
|
||||
|
||||
// Get the lengths of each dimension
|
||||
auto thread_tensor_lengths = o_acc_tile.get_lengths();
|
||||
|
||||
// Total number of elements
|
||||
index_t total_elements = OAccTile::kThreadElementSpaceSize;
|
||||
|
||||
// Iterate over all elements
|
||||
for(index_t linear_idx = 0; linear_idx < total_elements; ++linear_idx)
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
// Convert linear index to multi-dimensional indices
|
||||
array<index_t, kRank> indices;
|
||||
index_t remaining = linear_idx;
|
||||
static_for<0, kRank, 1>{}([&](auto i) {
|
||||
constexpr auto rev_i = kRank - 1 - i;
|
||||
indices(rev_i) = remaining % thread_tensor_lengths.get(number<rev_i>{});
|
||||
remaining /= thread_tensor_lengths.get(number<rev_i>{});
|
||||
});
|
||||
|
||||
// Apply the permutation
|
||||
array<index_t, kRank> permuted_indices;
|
||||
static_for<0, kRank, 1>{}(
|
||||
[&](auto i) { permuted_indices(i) = indices.get(number<Problem::kPerm[i]>{}); });
|
||||
|
||||
// Compute offsets
|
||||
index_t dst_offset = 0;
|
||||
index_t stride = 1;
|
||||
|
||||
static_for<0, kRank, 1>{}([&](auto i) {
|
||||
constexpr auto rev_i = kRank - 1 - i;
|
||||
dst_offset += permuted_indices[rev_i] * stride;
|
||||
stride *= thread_tensor_lengths.get(number<rev_i>{});
|
||||
});
|
||||
|
||||
// Move the data
|
||||
permuted_thread_buf(dst_offset) = thread_buf[linear_idx];
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
make_tuple(number<kNWave * kNPerXdl>{}, number<1>{}));
|
||||
}
|
||||
|
||||
// Copy the permuted data back to the original thread buffer
|
||||
for(index_t i = 0; i < total_elements; ++i)
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
thread_buf.set_as(i, permuted_thread_buf.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, OAccTile& o_acc_tile)
|
||||
{
|
||||
const auto& current_window_origin = o_dram_window_tmp.get_window_origin();
|
||||
|
||||
// Compute the tile coordinates by dividing the window origin by the tile sizes
|
||||
index_t tile_coords[CK_TILE_MAX_RANK] = {0};
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
tile_coords[i] = current_window_origin[i] / tile_sizes[i];
|
||||
// printf("The tile_coord is: %d", tile_coords[i]);
|
||||
}
|
||||
|
||||
// Apply the permutation to the tile coordinates
|
||||
index_t permuted_tile_coords[CK_TILE_MAX_RANK];
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
permuted_tile_coords[i] = tile_coords[kPerm[i]];
|
||||
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
|
||||
}
|
||||
|
||||
// Compute the permuted window origin
|
||||
index_t permuted_window_origin[CK_TILE_MAX_RANK] = {0};
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
permuted_window_origin[i] = permuted_tile_coords[i] * tile_sizes[i];
|
||||
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
|
||||
}
|
||||
|
||||
typename ODramWindowTmp::BottomTensorIndex step = {};
|
||||
for(index_t i = 0; i < kRank; ++i)
|
||||
{
|
||||
step[i] = permuted_window_origin[i] - current_window_origin[i];
|
||||
}
|
||||
|
||||
// Move the window
|
||||
move_tile_window(o_dram_window_tmp, step);
|
||||
|
||||
// Permute the data within the tile if necessary
|
||||
if constexpr(kTilePermute)
|
||||
{
|
||||
permute_tile_data(o_acc_tile);
|
||||
}
|
||||
|
||||
// Store the tile data to the permuted location
|
||||
if constexpr(kPadM || kPadN)
|
||||
{
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
buffer_store_fence();
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
make_tuple(number<1>{}, number<kMWave * kMPerXdl>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return kMWave * kNWave * kMPerXdl * kNPerXdl * sizeof(ODataType);
|
||||
}
|
||||
|
||||
template <typename ODramWindow,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindow& out_dram_window, const OAccTile& o_acc_tile, void* p_smem)
|
||||
{
|
||||
|
||||
const index_t iMWarp = get_warp_id() / kNWave;
|
||||
const index_t iNWarp = get_warp_id() - iMWarp * kNWave;
|
||||
|
||||
constexpr auto lds_block_desc = MakeLdsBlockDescriptor<Problem>();
|
||||
auto o_lds_block = make_tensor_view<address_space_enum::lds>(
|
||||
static_cast<ODataType*>(p_smem), lds_block_desc);
|
||||
auto in_lds_window =
|
||||
make_tile_window(o_lds_block,
|
||||
make_tuple(number<kMPerXdl>{}, number<kNPerXdl>{}),
|
||||
{number<kMPerXdl>{} * iMWarp, number<kNPerXdl>{} * iNWarp});
|
||||
auto out_lds_window =
|
||||
make_tile_window(o_lds_block,
|
||||
make_tuple(number<kMWave * kMPerXdl>{}, number<kNWave * kNPerXdl>{}),
|
||||
{0, 0});
|
||||
|
||||
using SFC = space_filling_curve<sequence<kMPerBlock, kNPerBlock>,
|
||||
sequence<0, 1>,
|
||||
sequence<kMPerXdl * kMWave, kNPerXdl * kNWave>>;
|
||||
constexpr index_t num_access = SFC::get_num_of_access();
|
||||
|
||||
using TileEncodingPattern =
|
||||
TileDistributionEncodingPattern2D<kBlockSize,
|
||||
kMPerIteration,
|
||||
kNPerIteration,
|
||||
GetVectorSizeC(),
|
||||
tile_distribution_pattern::thread_raked>;
|
||||
constexpr auto dram_tile_distribution = TileEncodingPattern::Make2DStaticTileDistribution();
|
||||
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
CWarpTensor c_warp_in_tensor;
|
||||
static_for<0, num_access, 1>{}([&](auto iAccess) {
|
||||
constexpr auto idx_y_start = SFC::get_index(iAccess);
|
||||
|
||||
constexpr auto mIter = number<idx_y_start.at(number<0>{}) / (kMPerXdl * kMWave)>{};
|
||||
constexpr auto nIter = number<idx_y_start.at(number<1>{}) / (kNPerXdl * kNWave)>{};
|
||||
|
||||
c_warp_in_tensor.get_thread_buffer() = o_acc_tile.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
const auto c_warp_in_tensor_casted = cast_tile<ODataType>(c_warp_in_tensor);
|
||||
|
||||
block_sync_lds();
|
||||
store_tile(in_lds_window, c_warp_in_tensor_casted);
|
||||
block_sync_lds();
|
||||
|
||||
const auto c_out_tensor =
|
||||
load_tile(make_tile_window(out_lds_window, dram_tile_distribution));
|
||||
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
store_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
update_tile(out_dram_window, c_out_tensor);
|
||||
}
|
||||
}
|
||||
if constexpr(iAccess != num_access - 1)
|
||||
{
|
||||
constexpr auto step = SFC::get_forward_step(iAccess);
|
||||
move_tile_window(out_dram_window, {step.at(number<0>{}), step.at(number<1>{})});
|
||||
}
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
|
||||
static constexpr bool UseRawStore = UseRawStore_;
|
||||
};
|
||||
|
||||
template <typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename CLayout_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
index_t kMPerXdl_,
|
||||
index_t kNPerXdl_,
|
||||
index_t kKPerXdl_,
|
||||
bool isCTransposed_,
|
||||
bool UseRawStore_ = true>
|
||||
struct DefaultGemm2DEpilogueProblem
|
||||
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
|
||||
{
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
static constexpr index_t kMPerXdl = kMPerXdl_;
|
||||
static constexpr index_t kNPerXdl = kNPerXdl_;
|
||||
static constexpr index_t kKPerXdl = kKPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct Default2DEpilogue
|
||||
{
|
||||
@@ -35,14 +57,13 @@ struct Default2DEpilogue
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
|
||||
{
|
||||
|
||||
// TODO: this is ugly
|
||||
@@ -71,4 +92,76 @@ struct Default2DEpilogue
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
static constexpr index_t kKPerXdl = Problem::kKPerXdl;
|
||||
static constexpr index_t isCTransposed = Problem::isCTransposed;
|
||||
|
||||
using WG = WarpGemmMfmaDispatcher<ODataType,
|
||||
ODataType,
|
||||
AccDataType,
|
||||
kMPerXdl,
|
||||
kNPerXdl,
|
||||
kKPerXdl,
|
||||
isCTransposed>;
|
||||
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(isCTransposed)
|
||||
{
|
||||
// In this case each thread has multiple consecutive elements in
|
||||
// N dimension, however consecutive threads' elements have stride.
|
||||
constexpr index_t NDimY = CWarpDstr::NDimY;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
|
||||
c_warp_y_lengths.get(number<NDimY - 1>{}));
|
||||
return c_warp_y_lengths.get(number<NDimY - 1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// In this case each thread has just a single item in Ndim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
}
|
||||
}
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if constexpr(isCTransposed)
|
||||
{
|
||||
// In this case each thread has just a single item in Mdim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
}
|
||||
else
|
||||
{
|
||||
// In this case each thread has multiple consecutive elements in
|
||||
// M dimension, however consecutive threads' elements have stride.
|
||||
constexpr index_t NDimY = CWarpDstr::NDimY;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
|
||||
c_warp_y_lengths.get(number<NDimY - 1>{}));
|
||||
return c_warp_y_lengths.get(number<NDimY - 1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user