Update unsigned long literals and format specifiers to work correctly in Windows (#3483)

Previously, the code used unsigned long for literals and format specifiers to represent 64-bit unsigned values. While this worked on Linux, it caused compatibility issues on Windows.
The C++ standard does not guarantee that long is 64 bits. On LP64 systems (e.g., Linux), long maps to 64-bit values, but on LLP64 systems (e.g., Windows), long maps to 32-bit values. This discrepancy led to incorrect behavior when assuming unsigned long was always 64-bit.
This commit updates all relevant literals and format specifiers to explicitly use 64-bit unsigned types, ensuring consistent behavior across platforms.
This commit is contained in:
John Afaganis
2026-01-02 22:16:41 -07:00
committed by GitHub
parent 4670df5ca6
commit ec23be0b9d
6 changed files with 22 additions and 14 deletions

View File

@@ -246,9 +246,11 @@ struct UniversalGemmBasePolicy
}
else // A is in RowMajor
{
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto DataTypeSize = sizeof(ADataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto MLdsLayer =
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
max(MinLdsLayer,
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
@@ -442,11 +444,13 @@ struct UniversalGemmBasePolicy
}
else // B is Column Major
{
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto BK0 = number<KPerBlock / KPack>{};
constexpr auto DataTypeSize = sizeof(BDataType);
constexpr uint64_t MinLdsLayer = 1ULL;
constexpr auto NLdsLayer =
max(1UL, get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
max(MinLdsLayer,
get_n_lds_banks() * get_n_words_per_128b() / KPerBlock / DataTypeSize);
constexpr index_t NBanks = get_n_lds_banks();
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");