Skip to content

Faster floating-point array quantization #720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 21, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 68 additions & 1 deletion lib/test/apyfloatarray/test_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import pytest

from apytypes import APyFixed, APyFloat, APyFloatArray, QuantizationMode
from apytypes import APyFixed, APyFloat, APyFloatArray, QuantizationMode, fp


@pytest.mark.float_array
Expand Down Expand Up @@ -370,6 +370,73 @@ def test_cast():
assert fp_array.cast(man_bits=5).is_identical(answer)


@pytest.mark.float_array
@pytest.mark.parametrize("exp_bits", range(4, 8))
@pytest.mark.parametrize("man_bits", range(7, 12))
@pytest.mark.parametrize(
    "quantization",
    [
        QuantizationMode.RND_CONV,
        QuantizationMode.JAM_UNBIASED,
        QuantizationMode.RND_CONV_ODD,
        QuantizationMode.RND_ZERO,
        QuantizationMode.RND_INF,
        QuantizationMode.RND_MIN_INF,
        QuantizationMode.TIES_AWAY,
        QuantizationMode.TIES_EVEN,
        QuantizationMode.TIES_NEG,
        QuantizationMode.TIES_POS,
        QuantizationMode.TO_AWAY,
        QuantizationMode.TO_ZERO,
        QuantizationMode.TO_NEG,
        QuantizationMode.TO_POS,
    ],
)
def test_cast_array_parametrized(exp_bits, man_bits, quantization):
    """Casting an array must give the same elements as casting each scalar."""
    inputs = [0, float("inf"), float("nan"), 1e-2, 1, 100, -1e-2, -1, -100]
    scalars = [fp(x, exp_bits, man_bits) for x in inputs]
    array = fp(inputs, exp_bits, man_bits)

    # The scalar and the array constructors must agree before anything is cast
    for idx, scalar in enumerate(scalars):
        assert scalar.is_identical(array[idx])

    # Destination formats, checked in this order
    targets = [
        (exp_bits - 1, man_bits - 2),  # decrease word length
        (exp_bits, man_bits - 2),  # decrease mantissa word length
        (exp_bits - 1, man_bits),  # decrease exponent word length
        (exp_bits + 1, man_bits + 2),  # increase word length
        (exp_bits, man_bits),  # no change in word length
    ]
    for new_exp_bits, new_man_bits in targets:
        casted = array.cast(new_exp_bits, new_man_bits, quantization=quantization)
        for idx, scalar in enumerate(scalars):
            expected = scalar.cast(
                new_exp_bits, new_man_bits, quantization=quantization
            )
            assert expected.is_identical(casted[idx])


@pytest.mark.float_array
def test_python_sum():
fx_array = APyFloatArray.from_float([1, 2, 3, 4, 5, 6], exp_bits=10, man_bits=10)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ unfixable = []
dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

[tool.ruff.lint.per-file-ignores]
"lib/test/*" = ["D", "ANN"]
"lib/test/*" = ["D", "ANN", "PLC0415"]
"examples/*" = ["ANN"]
"docs/*" = ["ANN"]
"comparison/*" = ["ANN"]
Expand Down
167 changes: 152 additions & 15 deletions src/apyfloat_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,15 +392,11 @@
QNTZ_FUNC_SIGNATURE qntz_func
)
{
// Handle special values first
// Handle special values first (zero handled later)
if (is_max_exponent(src, src_spec)) {
if (is_nan(src, src_spec)) {
return { src.sign, exp_t((1ULL << dst_spec.exp_bits) - 1), 1 }; // NaN
} else if (is_inf(src, src_spec)) {
return { src.sign, exp_t((1ULL << dst_spec.exp_bits) - 1), 0 }; // inf
}
} else if (is_zero(src)) {
return { src.sign, 0, 0 };
return { src.sign,
exp_t((1ULL << dst_spec.exp_bits) - 1),
man_t(src.man == 0 ? 0 : 1) }; // inf or nan
}

// Initial values for cast data
Expand All @@ -409,6 +405,9 @@

// Normalize the exponent and mantissa if converting from subnormal
if (src.exp == 0) {
if (src.man == 0) { // Zero
return { src.sign, 0, 0 };
}
exp_t subn_adjustment = count_trailing_bits(src.man);
man_t remainder = src.man % (1ULL << subn_adjustment);
exp = exp - src_spec.man_bits + subn_adjustment;
Expand Down Expand Up @@ -460,20 +459,111 @@
return res;
}

//! Cast a floating-point value from one format to another using pre-computed
//! constants, making it suitable for casting the elements of an array.
//!
//! The upper-case parameters do not depend on `src` and are meant to be derived once
//! per array from the source and destination formats:
//!   * `SRC_MAX_EXP`, `DST_MAX_EXP`: all-ones exponent field, `(1 << exp_bits) - 1`.
//!   * `SRC_LEADING_ONE`, `DST_LEADING_ONE`: `1 << man_bits` of the respective format.
//!   * `SPEC_MAN_BITS_DELTA`: destination mantissa bits minus source mantissa bits.
//!   * `SPEC_MAN_BITS_DELTA_REV`: `-SPEC_MAN_BITS_DELTA`.
//!   * `SRC_HIDDEN_ONE`: `1 << man_bits` of the source format.
//!   * `FINAL_STICKY`: `(1 << (SPEC_MAN_BITS_DELTA_REV - 1)) - 1`; only read when
//!     `SPEC_MAN_BITS_DELTA < 0`.
//!   * `BIAS_DELTA`: source bias minus destination bias.
//!
//! `qntz` selects the result when the value is too small for the destination format or
//! saturates it; `qntz_func` is invoked whenever mantissa bits have to be dropped.
template <typename QNTZ_FUNC_SIGNATURE>
[[maybe_unused]] static APY_INLINE APyFloatData array_floating_point_cast(
    const APyFloatData& src,
    const APyFloatSpec& src_spec,
    const APyFloatSpec& dst_spec,
    QuantizationMode qntz,
    QNTZ_FUNC_SIGNATURE qntz_func,
    const exp_t SRC_MAX_EXP,
    const exp_t DST_MAX_EXP,
    const man_t SRC_LEADING_ONE,
    const man_t DST_LEADING_ONE,
    const int SPEC_MAN_BITS_DELTA,
    const int SPEC_MAN_BITS_DELTA_REV,
    const man_t SRC_HIDDEN_ONE,
    const man_t FINAL_STICKY,
    const std::int64_t BIAS_DELTA
)
{
    // Handle special values first (zero handled later)
    if (src.exp == SRC_MAX_EXP) {
        return { src.sign, DST_MAX_EXP, man_t(src.man == 0 ? 0 : 1) }; // inf or nan
    }

    // Initial values for cast data. A source exponent field of zero gets `+ 1`.
    man_t man = src.man;
    std::int64_t exp = std::int64_t(src.exp) - BIAS_DELTA + (src.exp == 0);

    // Normalize the exponent and mantissa if converting from subnormal
    if (src.exp == 0) {
        if (src.man == 0) { // Zero
            return { src.sign, 0, 0 };
        }
        // Presumably the bit index of the most significant set bit of `src.man`;
        // `remainder` is then the mantissa with that bit removed.
        exp_t subn_adjustment = count_trailing_bits(src.man);
        man_t remainder = src.man % (1ULL << subn_adjustment);
        exp = exp - src_spec.man_bits + subn_adjustment;
        man = remainder << (src_spec.man_bits - subn_adjustment);
    }

    // Check if the result will be subnormal after cast
    if (exp <= 0) {
        if (exp < -std::int64_t(dst_spec.man_bits)) {
            // Exponent too small after rounding
            return { src.sign, 0, quantize_close_to_zero(src.sign, qntz) };
        }

        // Number of mantissa bits that must be dropped (when positive) to express the
        // value with a destination exponent field of zero
        const int MAN_BITS_DELTA = 1 - exp - SPEC_MAN_BITS_DELTA;
        man |= SRC_HIDDEN_ONE; // Add the hidden one
        if (MAN_BITS_DELTA <= 0) {
            // Nothing is dropped: a plain left shift, no quantization call
            return { src.sign, 0, (man << -MAN_BITS_DELTA) };
        } else { /* MAN_BITS_DELTA > 0 */
            const man_t STICKY = (1ULL << (MAN_BITS_DELTA - 1)) - 1;
            APyFloatData res = { src.sign, 0, man };
            // NOTE(review): the source-format leading one is passed here although the
            // result is stored in the destination format -- confirm against the
            // contract of `qntz_func`.
            qntz_func(
                res.man,
                res.exp,
                DST_MAX_EXP,
                MAN_BITS_DELTA,
                res.sign,
                SRC_LEADING_ONE,
                STICKY
            );
            return res;
        }
    }

    // Quantize the mantissa and return
    if (SPEC_MAN_BITS_DELTA >= 0) {
        // The destination mantissa is at least as wide: no mantissa bits are dropped
        if (exp >= DST_MAX_EXP) {
            if (do_infinity(qntz, src.sign)) {
                return { src.sign, DST_MAX_EXP, 0 }; // inf
            } else {
                return { src.sign,
                         DST_MAX_EXP - 1,
                         DST_LEADING_ONE - 1 }; // largest normal
            }
        } else { /* exp < DST_MAX_EXP */
            return { src.sign, exp_t(exp), (man << SPEC_MAN_BITS_DELTA) };
        }
    }

    // The destination mantissa is narrower: drop `SPEC_MAN_BITS_DELTA_REV` bits
    APyFloatData res { src.sign, exp_t(exp), man };
    qntz_func(
        res.man,
        res.exp,
        DST_MAX_EXP,
        SPEC_MAN_BITS_DELTA_REV,
        res.sign,
        DST_LEADING_ONE,
        FINAL_STICKY
    );
    return res;
}

//! Cast a floating-point number when it is known that no quantization happens
[[maybe_unused]] static APY_INLINE APyFloatData floating_point_cast_no_quant(
const APyFloatData& src, const APyFloatSpec& src_spec, const APyFloatSpec& dst_spec
)
{
// Handle special values first
if (is_max_exponent(src, src_spec.exp_bits)) {
if (src.man) {
return { src.sign, exp_t((1ULL << dst_spec.exp_bits) - 1), 1 }; // NaN
} else {
return { src.sign, exp_t((1ULL << dst_spec.exp_bits) - 1), 0 }; // +-inf
}
} else if (is_zero(src)) {
return { src.sign, 0, 0 };
return { src.sign,
exp_t((1ULL << dst_spec.exp_bits) - 1),
man_t(src.man == 0 ? 0 : 1) }; // Int or NaN
}

// Initial value for exponent
Expand All @@ -482,6 +572,9 @@
// Adjust the exponent and mantissa if convertering from a subnormal
man_t new_man;
if (src.exp == 0) {
if (src.man == 0) {
return { src.sign, 0, 0 };
}
const exp_t subn_adjustment = count_trailing_bits(src.man);
if (new_exp + subn_adjustment < src_spec.man_bits) {
// The result remains subnormal
Expand All @@ -501,6 +594,50 @@
return { src.sign, exp_t(new_exp), man_t(new_man) };
}

//! Cast a floating-point number when it is known that no quantization happens,
//! using pre-computed values suitable for arrays.
//!
//! The upper-case parameters do not depend on `src` and are meant to be derived once
//! per array from the source and destination formats:
//!   * `SRC_MAX_EXP`, `DST_MAX_EXP`: all-ones exponent field, `(1 << exp_bits) - 1`.
//!   * `SPEC_MAN_BITS_DELTA`: destination mantissa bits minus source mantissa bits;
//!     used as a left-shift amount, so it must not be negative.
//!   * `BIAS_DELTA`: source bias minus destination bias.
[[maybe_unused]] static APY_INLINE APyFloatData array_floating_point_cast_no_quant(
    const APyFloatData& src,
    const APyFloatSpec& src_spec,
    const exp_t SRC_MAX_EXP,
    const exp_t DST_MAX_EXP,
    const int SPEC_MAN_BITS_DELTA,
    const std::int64_t BIAS_DELTA
)
{
    // Handle special values first
    if (src.exp == SRC_MAX_EXP) {
        return { src.sign, DST_MAX_EXP, man_t(src.man == 0 ? 0 : 1) }; // inf or NaN
    }

    // Initial value for exponent. A source exponent field of zero gets `+ 1`.
    std::int64_t new_exp = std::int64_t(src.exp) - BIAS_DELTA + (src.exp == 0);

    // Adjust the exponent and mantissa if converting from a subnormal
    man_t new_man;
    if (src.exp == 0) {
        if (src.man == 0) {
            return { src.sign, 0, 0 };
        }
        // Assumed to be the bit index of the most significant set bit of `src.man`;
        // the normalization in the `else` branch relies on that.
        const exp_t subn_adjustment = count_trailing_bits(src.man);

        // Exponent field the value would get once its leading one is made implicit
        const std::int64_t normalized_exp = new_exp + std::int64_t(subn_adjustment)
            - std::int64_t(src_spec.man_bits);

        if (normalized_exp <= 0) {
            // The result remains subnormal. Both formats then use an exponent field of
            // zero, so only the bias difference re-scales the mantissa. `new_exp`
            // already contains the `+ 1` correction and must not be used as the shift
            // amount: that would double the value.
            // NOTE(review): assumes the destination bias is not smaller than the
            // source bias on this path (shift amount >= 0) -- confirm with callers.
            const std::int64_t bias_shift = -BIAS_DELTA;
            new_man = src.man << bias_shift;
            new_exp = 0;
        } else {
            // The result becomes normal: move the leading one up to the hidden-one
            // position and mask it away
            new_man = src.man << (src_spec.man_bits - subn_adjustment);
            new_man &= (1ULL << src_spec.man_bits) - 1;
            new_exp = normalized_exp;
        }
    } else {
        new_man = src.man;
    }

    // Widen the mantissa to the destination format
    new_man <<= SPEC_MAN_BITS_DELTA;
    return { src.sign, exp_t(new_exp), man_t(new_man) };
}

//! Return the bit pattern of a floating-point data field. No checks on bit width is
//! done.
[[maybe_unused]] static APY_INLINE std::uint64_t
Expand Down
58 changes: 40 additions & 18 deletions src/apyfloatarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1240,35 +1240,57 @@ APyFloatArray APyFloatArray::_cast(
return *this;
}

APyFloatArray result(_shape, new_exp_bits, new_man_bits, new_bias);

const exp_t SRC_MAX_EXP = (1ULL << exp_bits) - 1;
const exp_t DST_MAX_EXP = (1ULL << new_exp_bits) - 1;
const int SPEC_MAN_BITS_DELTA = new_man_bits - man_bits;
const std::int64_t BIAS_DELTA = std::int64_t(bias) - std::int64_t(new_bias);

// If longer word lengths, use simpler/faster method
if (new_exp_bits >= exp_bits && new_man_bits >= man_bits) {
return cast_no_quant(new_exp_bits, new_man_bits, new_bias);
for (std::size_t i = 0; i < _data.size(); i++) {
result._data[i] = array_floating_point_cast_no_quant(
_data[i],
spec(),
SRC_MAX_EXP,
DST_MAX_EXP,
SPEC_MAN_BITS_DELTA,
BIAS_DELTA
);
}
return result;
}

APyFloatArray result(_shape, new_exp_bits, new_man_bits, new_bias);
const auto quantization_func = get_qntz_func(quantization);
const man_t SRC_LEADING_ONE = (1ULL << man_bits);
const man_t DST_LEADING_ONE = (1ULL << new_man_bits);
const man_t SRC_HIDDEN_ONE = (1ULL << man_bits);
const int SPEC_MAN_BITS_DELTA_REV = -SPEC_MAN_BITS_DELTA;
const man_t FINAL_STICKY = (1ULL << (SPEC_MAN_BITS_DELTA_REV - 1)) - 1;

APyFloat caster(exp_bits, man_bits, bias);
for (std::size_t i = 0; i < _data.size(); i++) {
caster.set_data(_data[i]);
result._data[i]
= caster.checked_cast(new_exp_bits, new_man_bits, new_bias, quantization)
.get_data();
result._data[i] = array_floating_point_cast(
_data[i],
spec(),
result.spec(),
quantization,
quantization_func,
SRC_MAX_EXP,
DST_MAX_EXP,
SRC_LEADING_ONE,
DST_LEADING_ONE,
SPEC_MAN_BITS_DELTA,
SPEC_MAN_BITS_DELTA_REV,
SRC_HIDDEN_ONE,
FINAL_STICKY,
BIAS_DELTA
);
}

return result;
}

//! Build an array with this array's shape and the requested format, and fill it by
//! passing every element through `floating_point_cast_no_quant`.
APyFloatArray APyFloatArray::cast_no_quant(
    std::uint8_t new_exp_bits, std::uint8_t new_man_bits, std::optional<exp_t> new_bias
) const
{
    APyFloatArray casted(_shape, new_exp_bits, new_man_bits, new_bias);

    const std::size_t n_items = _data.size();
    for (std::size_t idx = 0; idx != n_items; ++idx) {
        const auto& element = _data[idx];
        casted._data[idx] = floating_point_cast_no_quant(element, spec(), casted.spec());
    }

    return casted;
}

// Evaluate the inner between two vectors. This method assumes that the the shape of
// both `*this` and `rhs` are equally long. Anything else is undefined behaviour.
APyFloat APyFloatArray::checked_inner_product(const APyFloatArray& rhs) const
Expand Down
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy