mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 21:09:08 +00:00
Merge commit 'ec23be0b9d45ff9ca4135090bcd0269184c953a7' into develop
This commit is contained in:
@@ -55,9 +55,10 @@ struct FillUniformDistribution
|
||||
const auto total_bytes = total * sizeof(T_iter);
|
||||
|
||||
// max 80 threads; at least 2MB per thread
|
||||
const size_t available_cpu_cores = get_available_cpu_cores();
|
||||
const size_t num_thread =
|
||||
min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL));
|
||||
const size_t available_cpu_cores = get_available_cpu_cores();
|
||||
constexpr uint64_t MAX_THREAD_COUNT = 80;
|
||||
const size_t num_thread = min(
|
||||
MAX_THREAD_COUNT, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL));
|
||||
constexpr size_t BLOCK_BYTES = 64;
|
||||
constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter);
|
||||
const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES);
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cinttypes>
|
||||
#include <cstdlib>
|
||||
#include <thread>
|
||||
|
||||
@@ -28,7 +29,7 @@ CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor<InDataType>& input,
|
||||
output.get_num_of_dimension() == NDimSpatial + 3))
|
||||
{
|
||||
|
||||
printf("%lu %lu %lu",
|
||||
printf("%" PRIu64 " %" PRIu64 " %" PRIu64,
|
||||
input.get_num_of_dimension(),
|
||||
weight.get_num_of_dimension(),
|
||||
output.get_num_of_dimension());
|
||||
|
||||
@@ -246,9 +246,11 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
else // A is in RowMajor
|
||||
{
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto MLdsLayer =
|
||||
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
@@ -442,11 +444,13 @@ struct UniversalGemmBasePolicy
|
||||
}
|
||||
else // B is Column Major
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
constexpr auto BK0 = number<KPerBlock / KPack>{};
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
constexpr auto BK0 = number<KPerBlock / KPack>{};
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto NLdsLayer =
|
||||
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
|
||||
Reference in New Issue
Block a user