diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index e8fa102643..e29ba272f5 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -2,6 +2,13 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -88,48 +95,16 @@ int run_gemm_example_with_layouts(int argc, int n_warmup = arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); - using namespace ck_tile::literals; + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); - auto f_host_tensor_descriptor = - [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return ck_tile::HostTensorDescriptor({row, col}, {stride, 1_uz}); - } - else - { - return ck_tile::HostTensorDescriptor({row, col}, {1_uz, stride}); - } - }; - - auto f_get_default_stride = [](std::size_t row, - std::size_t col, - std::size_t stride, - auto layout) { - if(stride == 0) - { - // give a chance if stride is zero, return a default packed stride - if constexpr(std::is_same_v) - { - return col; - } - else - { - return row; - } - } - else - return stride; - }; - - stride_A = f_get_default_stride(M, K, stride_A, a_layout); - stride_B = f_get_default_stride(K, N, stride_B, b_layout); - stride_C = f_get_default_stride(M, N, stride_C, CLayout{}); - - ck_tile::HostTensor a_m_k(f_host_tensor_descriptor(M, K, stride_A, a_layout)); - ck_tile::HostTensor b_k_n(f_host_tensor_descriptor(K, N, stride_B, b_layout)); + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); ck_tile::HostTensor c_m_n_dev_result( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); // TODO: add different init types ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); @@ -163,7 +138,7 @@ int run_gemm_example_with_layouts(int argc, if(arg_parser.get_int("v") == 1) { ck_tile::HostTensor c_m_n_host_ref( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); ck_tile::reference_gemm( @@ -185,7 +160,7 @@ int run_gemm_example_with_layouts(int argc, else if(arg_parser.get_int("v") == 2) { ck_tile::HostTensor c_m_n_gpu_ref( - f_host_tensor_descriptor(M, N, stride_C, CLayout{})); + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); diff --git a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc index 2fe81e87c4..c3ed76f5ef 100644 --- a/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc +++ b/example/ck_tile/16_batched_gemm/run_batched_gemm_example.inc @@ -3,6 +3,13 @@ #pragma once +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -106,56 +113,16 @@ int run_batched_gemm_example_with_layouts(int argc, int n_warmup = arg_parser.get_int("warmup"); int n_repeat = arg_parser.get_int("repeat"); - using namespace ck_tile::literals; + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(c_layout)); - auto f_host_tensor_descriptor = [](std::size_t batch_count_, - std::size_t row, - std::size_t col, - std::size_t stride, - std::size_t batch_stride, - auto layout) { - if constexpr(std::is_same_v) - { - return ck_tile::HostTensorDescriptor({batch_count_, row, col}, - {batch_stride, stride, 1_uz}); - } - else - { - return ck_tile::HostTensorDescriptor({batch_count_, row, col}, - {batch_stride, 1_uz, stride}); - } - }; - - auto f_get_default_stride = [](std::size_t row, - std::size_t col, - std::size_t stride, - auto layout) { - if(stride == 0) - { - // give a chance if stride is zero, return a default packed stride - if constexpr(std::is_same_v) - { - return col; - } - else - { - return row; - } - } - else - return stride; - }; - - stride_A = f_get_default_stride(M, K, stride_A, a_layout); - stride_B = f_get_default_stride(K, N, stride_B, b_layout); - stride_C = f_get_default_stride(M, N, stride_C, c_layout); - - ck_tile::HostTensor a_m_k( - f_host_tensor_descriptor(batch_count, M, K, stride_A, batch_stride_A, a_layout)); - ck_tile::HostTensor b_k_n( - f_host_tensor_descriptor(batch_count, K, N, stride_B, batch_stride_B, b_layout)); - ck_tile::HostTensor c_m_n_dev_result( - f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, c_layout)); + ck_tile::HostTensor a_m_k(ck_tile::host_tensor_descriptor( + batch_count, M, K, stride_A, batch_stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n(ck_tile::host_tensor_descriptor( + batch_count, K, N, stride_B, batch_stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result(ck_tile::host_tensor_descriptor( + batch_count, M, N, stride_C, batch_stride_C, is_row_major(c_layout))); ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); @@ -191,8 +158,8 @@ int run_batched_gemm_example_with_layouts(int argc, if(arg_parser.get_int("v") == 1) { - ck_tile::HostTensor c_m_n_host_ref( - f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout){})); c_m_n_host_ref.SetZero(); const auto b_n_k = b_k_n.transpose({0, 2, 1}); @@ -216,8 +183,8 @@ int run_batched_gemm_example_with_layouts(int argc, } else if(arg_parser.get_int("v") == 2) { - ck_tile::HostTensor c_m_n_gpu_ref( - f_host_tensor_descriptor(batch_count, M, N, stride_C, batch_stride_C, CLayout{})); + ck_tile::HostTensor c_m_n_gpu_ref(ck_tile::host_tensor_descriptor( + batch_count, M, N, stride_C, batch_stride_C, is_row_major(CLayout){})); ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); c_m_n_gpu_ref.SetZero(); c_m_n_gpu_buf_ref.SetZero(); diff --git a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc index bd7ce38007..34b6ee666c 100644 --- a/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc +++ b/example/ck_tile/17_grouped_gemm/run_grouped_gemm_example.inc @@ -3,6 +3,13 @@ #pragma once +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + auto calculate_rtol_atol(const ck_tile::index_t K, const ck_tile::index_t kbatch, const float max_accumulated_value) @@ -128,16 +135,19 @@ int run_grouped_gemm_example_with_layouts(int argc, const ck_tile::index_t N = Ns[i]; const ck_tile::index_t K = Ks[i]; - stride_As[i] = ck_tile::get_default_stride(M, N, stride_As[i], a_layout); - stride_Bs[i] = ck_tile::get_default_stride(K, N, stride_Bs[i], b_layout); - stride_Cs[i] = ck_tile::get_default_stride(M, N, stride_Cs[i], CLayout{}); + stride_As[i] = + ck_tile::get_default_stride(M, N, stride_As[i], is_row_major(a_layout)); + stride_Bs[i] = + ck_tile::get_default_stride(K, N, stride_Bs[i], is_row_major(b_layout)); + stride_Cs[i] = + ck_tile::get_default_stride(M, N, stride_Cs[i], is_row_major(CLayout{})); a_m_k_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(M, K, stride_As[i], a_layout))); + ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout)))); b_k_n_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], b_layout))); + ck_tile::host_tensor_descriptor(K, N, stride_Bs[i], is_row_major(b_layout)))); c_m_n_tensors.push_back(ck_tile::HostTensor( - ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], CLayout{}))); + ck_tile::host_tensor_descriptor(M, N, stride_Cs[i], is_row_major(CLayout{})))); std::cout << "gemm[" << i << "]" << " a_m_k: " << a_m_k_tensors[i].mDesc << " b_k_n: " << b_k_n_tensors[i].mDesc @@ -177,8 +187,8 @@ int run_grouped_gemm_example_with_layouts(int argc, { for(int i = 0; i < group_count; ++i) { - ck_tile::HostTensor c_m_n_host_ref( - ck_tile::host_tensor_descriptor(Ms[i], Ns[i], stride_Cs[i], CLayout{})); + ck_tile::HostTensor c_m_n_host_ref(ck_tile::host_tensor_descriptor( + Ms[i], Ns[i], stride_Cs[i], is_row_major(CLayout{}))); c_m_n_host_ref.SetZero(); ck_tile::reference_gemm( a_m_k_tensors[i], b_k_n_tensors[i], c_m_n_host_ref); diff --git a/include/ck_tile/host/host_tensor.hpp b/include/ck_tile/host/host_tensor.hpp index 2babb2afe9..2047ad7793 100644 --- a/include/ck_tile/host/host_tensor.hpp +++ b/include/ck_tile/host/host_tensor.hpp @@ -679,12 +679,15 @@ struct HostTensor Data mData; }; -template -auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride, TLayout layout) +template +auto host_tensor_descriptor(std::size_t row, + std::size_t col, + std::size_t stride, + bool_constant) { using namespace ck_tile::literals; - if constexpr(std::is_same_v) + if constexpr(is_row_major) { return HostTensorDescriptor({row, col}, {stride, 1_uz}); } @@ -693,12 +696,15 @@ auto host_tensor_descriptor(std::size_t row, std::size_t col, std::size_t stride return HostTensorDescriptor({row, col}, {1_uz, stride}); } } -template -auto get_default_stride(std::size_t row, std::size_t col, std::size_t stride, TLayout layout) +template +auto get_default_stride(std::size_t row, + std::size_t col, + std::size_t stride, + bool_constant) { if(stride == 0) { - if constexpr(std::is_same_v) + if constexpr(is_row_major) { return col; }