mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
[CK_TILE] not using structures under ck_tile/ops for ck_tile/host (#1834)
* Not using structures under ck_tile/ops for ck_tile/host
* Update as constexpr function
* Rename fn
* Update other examples

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
Co-authored-by: Adam Osewski <Adam.Osewski@amd.com>
This commit is contained in:
@@ -3,6 +3,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
@@ -106,56 +113,16 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
|
||||
using namespace ck_tile::literals;
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout));
|
||||
|
||||
auto f_host_tensor_descriptor = [](std::size_t batch_count_,
|
||||
std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
std::size_t batch_stride,
|
||||
auto layout) {
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
|
||||
{batch_stride, stride, 1_uz});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ck_tile::HostTensorDescriptor({batch_count_, row, col},
|
||||
{batch_stride, 1_uz, stride});
|
||||
}
|
||||
};
|
||||
|
||||
auto f_get_default_stride = [](std::size_t row,
|
||||
std::size_t col,
|
||||
std::size_t stride,
|
||||
auto layout) {
|
||||
if(stride == 0)
|
||||
{
|
||||
// give a chance if stride is zero, return a default packed stride
|
||||
if constexpr(std::is_same_v<decltype(layout), ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return col;
|
||||
}
|
||||
else
|
||||
{
|
||||
return row;
|
||||
}
|
||||
}
|
||||
else
|
||||
return stride;
|
||||
};
|
||||
|
||||
stride_A = f_get_default_stride(M, K, stride_A, a_layout);
|
||||
stride_B = f_get_default_stride(K, N, stride_B, b_layout);
|
||||
stride_C = f_get_default_stride(M, N, stride_C, c_layout);
|
||||
|
||||
// Host-side tensors for A (batch x M x K), B (batch x K x N) and the
// device-computed result C (batch x M x N). The descriptor helper derives
// the 3-D strides from the layout tag (row- vs column-major).
//
// NOTE(review): this span previously declared each tensor twice — once via
// the removed f_host_tensor_descriptor lambda and once via
// ck_tile::host_tensor_descriptor — which is a redeclaration error in a
// single scope. Only the updated form is kept.
ck_tile::HostTensor<ADataType> a_m_k(ck_tile::host_tensor_descriptor(
    batch_count, M, K, stride_A, batch_stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<BDataType> b_k_n(ck_tile::host_tensor_descriptor(
    batch_count, K, N, stride_B, batch_stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(ck_tile::host_tensor_descriptor(
    batch_count, M, N, stride_C, batch_stride_C, is_row_major(c_layout)));
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
@@ -191,8 +158,8 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(ck_tile::host_tensor_descriptor(
|
||||
batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout){}));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
const auto b_n_k = b_k_n.transpose({0, 2, 1});
|
||||
@@ -216,8 +183,8 @@ int run_batched_gemm_example_with_layouts(int argc,
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{}));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(ck_tile::host_tensor_descriptor(
|
||||
batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout){}));
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
Reference in New Issue
Block a user