Merge commit 'ec23be0b9d45ff9ca4135090bcd0269184c953a7' into develop

This commit is contained in:
assistant-librarian[bot]
2026-01-03 06:16:07 +00:00
parent 0b05cd0351
commit d8734393c7
6 changed files with 22 additions and 14 deletions

View File

@@ -55,9 +55,10 @@ struct FillUniformDistribution
const auto total_bytes = total * sizeof(T_iter);
// max 80 threads; at least 2MB per thread
const size_t available_cpu_cores = get_available_cpu_cores();
const size_t num_thread =
min(80UL, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL));
const size_t available_cpu_cores = get_available_cpu_cores();
constexpr uint64_t MAX_THREAD_COUNT = 80;
const size_t num_thread = min(
MAX_THREAD_COUNT, available_cpu_cores, integer_divide_ceil(total_bytes, 0x200000UL));
constexpr size_t BLOCK_BYTES = 64;
constexpr size_t BLOCK_SIZE = BLOCK_BYTES / sizeof(T_iter);
const size_t num_blocks = integer_divide_ceil(total_bytes, BLOCK_BYTES);

View File

@@ -3,6 +3,7 @@
#pragma once
#include <cinttypes>
#include <cstdlib>
#include <thread>
@@ -28,7 +29,7 @@ CK_TILE_HOST void reference_grouped_conv_bwd_data(HostTensor<InDataType>& input,
output.get_num_of_dimension() == NDimSpatial + 3))
{
printf("%lu %lu %lu",
printf("%" PRIu64 " %" PRIu64 " %" PRIu64,
input.get_num_of_dimension(),
weight.get_num_of_dimension(),
output.get_num_of_dimension());

View File

@@ -246,9 +246,11 @@ struct UniversalGemmBasePolicy
}
else // A is in RowMajor
{
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto MLdsLayer =
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
max(MinLdsLayer,
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
@@ -442,11 +444,13 @@ struct UniversalGemmBasePolicy
}
else // B is Column Major
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto NLdsLayer =
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
max(MinLdsLayer,
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");