debug mixed_prec flatmm

This commit is contained in:
Feng Shijie
2025-08-07 09:22:04 +00:00
parent 0ba513b148
commit 3dea10a277
10 changed files with 2193 additions and 44 deletions

View File

@@ -28,10 +28,60 @@ struct FlatmmProblem
index_t stride_C;
};
template <int SharedGranularity>
template <int SharedGranularityMN, int SharedGranularityK = 0>
struct FlatmmScalePointer
{
static constexpr int granularity = SharedGranularity;
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = SharedGranularityK;
const float* ptr;
index_t scale_stride = 1;
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, index_t stride)
: ptr(ptr_), scale_stride(stride)
{
}
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
{
FlatmmScalePointer ret;
// if constexpr(GranularityMN == 0)
// {
// ret.scalar = scalar;
// }
// else if constexpr(GranularityMN == 1)
// {
// ret.ptr = ptr + offset;
// }
// else
// {
// ret.ptr = ptr + offset / GranularityMN;
// }
return ret;
}
CK_TILE_HOST_DEVICE float operator[](index_t i) const
{
if constexpr(GranularityMN == 1)
{
return ptr[i];
}
else
{
return ptr[i / GranularityMN];
}
}
};
template <int SharedGranularityMN>
struct FlatmmScalePointer<SharedGranularityMN, 0>
{
static constexpr int GranularityMN = SharedGranularityMN;
static constexpr int GranularityK = 0;
static_assert(GranularityMN != 0);
union
{
@@ -42,50 +92,63 @@ struct FlatmmScalePointer
CK_TILE_HOST_DEVICE FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE FlatmmScalePointer(float scalar_) : scalar(scalar_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_) : ptr(ptr_) {}
CK_TILE_HOST_DEVICE FlatmmScalePointer(const float* ptr_, [[maybe_unused]] index_t stride)
: ptr(ptr_)
{
}
CK_TILE_HOST_DEVICE FlatmmScalePointer operator+(index_t offset) const
{
FlatmmScalePointer ret;
if constexpr(granularity == 0)
if constexpr(GranularityMN == 0)
{
ret.scalar = scalar;
}
else if constexpr(granularity == 1)
else if constexpr(GranularityMN == 1)
{
ret.ptr = ptr + offset;
}
else
{
ret.ptr = ptr + offset / granularity;
ret.ptr = ptr + offset / GranularityMN;
}
return ret;
}
CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; }
CK_TILE_HOST_DEVICE float operator[](index_t i) const
{
if constexpr(granularity == 0)
if constexpr(GranularityMN == 0)
{
return scalar;
}
else if constexpr(granularity == 1)
else if constexpr(GranularityMN == 1)
{
return ptr[i];
}
else
{
return ptr[i / granularity];
return ptr[i / GranularityMN];
}
}
};
// shared granularity = -1 means no scale
// shared granularityMN = -1 means no scale
template <>
struct FlatmmScalePointer<-1>
struct FlatmmScalePointer<-1, 0>
{
static constexpr int granularity = -1;
static constexpr int GranularityMN = -1;
static constexpr int GranularityK = 0;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer() = default;
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(float scalar_) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float* ptr_) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(float) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*) {}
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer(const float*, [[maybe_unused]] index_t stride)
{
}
CK_TILE_HOST_DEVICE FlatmmScalePointer& advance() { return *this; }
CK_TILE_HOST_DEVICE constexpr FlatmmScalePointer operator+(index_t) const
{
@@ -150,7 +213,6 @@ struct BaseFlatmmHostArgs
index_t k_batch;
};
template <class ScaleM = FlatmmScalePointer<-1>,
class ScaleN = FlatmmScalePointer<-1>,
index_t NumDTensor = 0>
@@ -558,9 +620,9 @@ struct FlatmmKernel
}
}();
index_t kFlatK = FlatmmPipeline::flatKPerWarp * (kargs.K /
BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
index_t kFlatK =
FlatmmPipeline::flatKPerWarp * (kargs.K / BlockGemmShape::WarpTile::at(I2));
index_t kFlatN = kargs.N * kargs.K / kFlatK;
const auto& b_flat_tensor_view = [&]() {
return make_naive_tensor_view<address_space_enum::global>(
b_flat_ptr,
@@ -776,7 +838,7 @@ struct FlatmmKernel
a_block_window, b_flat_block_window, num_loop, smem_ptr_ping, smem_ptr_pong);
// Run Epilogue Pipeline
if constexpr(ScaleM::granularity != -1 || ScaleN::granularity != -1)
if constexpr(ScaleM::GranularityMN != -1 || ScaleN::GranularityMN != -1)
{
auto& c_block_window = gemm_tile_windows.at(I3);
EpiloguePipeline{}.template