diff --git a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc index 8fb947fa60..2d72b373bd 100644 --- a/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc +++ b/example/ck_tile/19_grouped_flatmm/run_grouped_flatmm_example.inc @@ -17,21 +17,17 @@ auto shuffle_b(const ck_tile::HostTensor& t) int n_ = t.get_lengths()[1]; int k_ = t.get_lengths()[0]; - if constexpr(GemmConfig::N_Warp_Tile == 32) - { - ck_tile::HostTensor t_view( - {n_ / 32, 32, k_ / GemmConfig::K_Warp_Tile, 2, GemmConfig::K_Warp_Tile / 2}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } - else - { - static_assert(GemmConfig::N_Warp_Tile == 16); - ck_tile::HostTensor t_view( - {n_ / 16, 16, k_ / GemmConfig::K_Warp_Tile, 4, GemmConfig::K_Warp_Tile / 4}); - std::copy(t.begin(), t.end(), t_view.begin()); - return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); - } + constexpr int N_Warp_Tile = GemmConfig::N_Warp_Tile; + constexpr int N_Warp = GemmConfig::N_Warp; + constexpr int KPerLane = GemmConfig::K_Warp_Tile / (64 / N_Warp_Tile); + + ck_tile::HostTensor t_view({n_ / N_Warp_Tile, + N_Warp_Tile, + k_ / (64 * KPerLane / N_Warp_Tile), + 64 / N_Warp_Tile, + KPerLane}); + std::copy(t.begin(), t.end(), t_view.begin()); + return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4}); } template @@ -414,7 +410,7 @@ int run_contiguous_grouped_flatmm_example_with_layouts( ck_tile::HostTensor c_m_n_tensor(ck_tile::HostTensor( ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(c_layout)))); - std::vector m_indices(std::size_t(M), -1); + std::vector m_indices(M); int indices_fill_start = 0; for(int i = 0; i < group_count; ++i) { @@ -428,9 +424,11 @@ int run_contiguous_grouped_flatmm_example_with_layouts( } ck_tile::FillUniformDistribution{-1.f, 1.f}(a_m_k_tensor); - ck_tile::FillUniformDistribution{-4.f, 4.f}(b_k_n_tensor); - c_m_n_tensor.SetZero(); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_k_n_tensor); + constexpr int N_Warp_Tile = GemmConfig::N_Warp_Tile; + assert(N % N_Warp_Tile == 0 && + "N must be divisible by N_Warp_Tile for contiguous grouped gemm"); ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n_tensor); std::unique_ptr a_m_k_dev_buf( @@ -444,6 +442,9 @@ int run_contiguous_grouped_flatmm_example_with_layouts( ck_tile::DeviceMem m_indices_dev_buf(M * sizeof(ck_tile::index_t)); m_indices_dev_buf.ToDevice(m_indices.data()); + a_m_k_dev_buf->ToDevice(a_m_k_tensor.data()); + b_shfl_dev_buf->ToDevice(b_shuffle_host.data()); + ck_tile::ContiguousGroupedFlatmmHostArgs kernal_args{ static_cast(m_indices_dev_buf.GetDeviceBuffer()), M,