r"""
Decibel Units
-------------

To create a dimensionless quantity in decibels, use the ``u.dB(1)`` unit rather than
``u.dB``. For example, ``3.01 * u.dB(1)``. The (1) informs Astropy that the reference
level is 1, which allows conversion from decibels to linear scale via
``.to(u.dimensionless)``. A bare ``u.dB`` has no defined reference level and Astropy
will refuse to convert it to ``u.dimensionless``.

For quantities with physical dimensions in decibels, use ``u.dB(unit)``. For example,
``3.01 * u.dB(u.W)``. Alternatively, use one of the aliases defined in this module for
common cases, such as ``u.dBW`` or ``u.dBm``.

Wavelength
----------

The relationship between wavelength and frequency is given by:

.. math::

   \lambda = \frac{c}{f}

where:

* :math:`\lambda` is the wavelength in m
* :math:`c` is the speed of light (299,792,458 m/s)
* :math:`f` is the frequency in Hz

Return Loss to VSWR
-------------------

The conversion from return loss in decibels to voltage standing wave ratio (VSWR) is
done using:

.. math::

   \text{VSWR} = \frac{1 + |\Gamma|}{1 - |\Gamma|}

where:

* :math:`|\Gamma|` is the magnitude of the reflection coefficient
* :math:`|\Gamma| = 10^{-\frac{\text{RL}}{20}}`
* :math:`\text{RL}` is the return loss in dB

VSWR to Return Loss
-------------------

The conversion from voltage standing wave ratio (VSWR) to return loss in decibels is
done using:

.. math::

   \text{RL} = -20 \log_{10}\left(\frac{\text{VSWR} - 1}{\text{VSWR} + 1}\right)

where:

* :math:`\text{VSWR}` is the voltage standing wave ratio
* :math:`\text{RL}` is the return loss in dB
"""
import dataclasses
import types
from collections.abc import Callable
from functools import wraps
from inspect import signature
from typing import (
Annotated,
Any,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
)
import astropy.constants as constants
import astropy.units as u
import numpy as np
from astropy.units import Quantity
# Type variable for the enforce_units decorator. It is constrained to either a
# callable or a class, so the decorator's declared return type matches its input:
# a function comes back as a (wrapped) function, and a dataclass comes back as the
# same class object with its __init__ wrapped.
FuncOrClass = TypeVar("FuncOrClass", Callable[..., Any], type)
# Register convenience unit aliases on astropy.units. A name that astropy (or an
# earlier import) already provides is left alone rather than overwritten.
for _alias_name, _alias_unit in (
    ("dBHz", u.dB(u.Hz)),
    ("dBW", u.dB(u.W)),
    ("dBm", u.dB(u.mW)),
    ("dBK", u.dB(u.K)),
    ("dB_per_K", u.dB(1 / u.K)),
    ("dimensionless", u.dimensionless_unscaled),
):
    if not hasattr(u, _alias_name):  # pragma: no cover
        setattr(u, _alias_name, _alias_unit)
del _alias_name, _alias_unit
# --- Logarithmic (decibel) quantities -------------------------------------------
# Using u.dB(1) allows conversion from decibels to u.dimensionless_unscaled: the (1)
# informs Astropy that the value is decibels relative to 1, whereas a bare u.dB has
# no defined reference point.
Decibels = Annotated[Quantity, u.dB(1)]
DecibelWatts = Annotated[Quantity, u.dB(u.W)]
DecibelMilliwatts = Annotated[Quantity, u.dB(u.mW)]
DecibelKelvins = Annotated[Quantity, u.dB(u.K)]
DecibelPerKelvin = Annotated[Quantity, u.dB(1 / u.K)]
DecibelHertz = Annotated[Quantity, u.dB(u.Hz)]

# --- Linear quantities (enforce_units converts arguments to these units) ---------
Power = Annotated[Quantity, u.W]
PowerDensity = Annotated[Quantity, u.W / u.Hz]
Frequency = Annotated[Quantity, u.Hz]
Temperature = Annotated[Quantity, u.K]
Dimensionless = Annotated[Quantity, u.dimensionless_unscaled]

# Three names for the same unit (metres), chosen for readability at call sites.
Wavelength = Annotated[Quantity, u.m]
Distance = Annotated[Quantity, u.m]
Length = Annotated[Quantity, u.m]

# Angles and time.
Angle = Annotated[Quantity, u.rad]
SolidAngle = Annotated[Quantity, u.sr]
Time = Annotated[Quantity, u.s]

# Module-level switch for return-unit validation; tests turn it on.
_RETURN_UNITS_CHECK_ENABLED = False
def _extract_annotated_from_hint(hint: Any) -> tuple[type, u.Unit] | None:
    """
    Find the ``Annotated`` base type and unit inside a type hint.

    The hint may be ``Annotated[...]`` itself, or a ``Union`` / ``X | Y`` with an
    ``Annotated[...]`` member (the shape optional parameters take, e.g.
    ``Frequency | None``). For a union, the first ``Annotated`` member wins.

    Parameters
    ----------
    hint : Any
        Type hint to inspect.

    Returns
    -------
    tuple[type, u.Unit] | None
        ``(quantity_type, unit)`` taken from the first two arguments of the
        ``Annotated`` hint, or None if no ``Annotated`` hint is present.
    """
    if hint is None:  # pragma: no cover
        return None

    def _base_and_unit(candidate):
        # Annotated[Base, meta, ...] -> (Base, meta); anything else -> None.
        if get_origin(candidate) is not Annotated:
            return None
        parts = get_args(candidate)
        return (parts[0], parts[1]) if len(parts) >= 2 else None

    direct = _base_and_unit(hint)
    if direct is not None:
        return direct

    # PEP 604 unions (X | Y) have origin types.UnionType rather than typing.Union.
    origin = get_origin(hint)
    is_union = origin is Union or (
        hasattr(types, "UnionType") and origin is types.UnionType
    )
    if not is_union:
        return None
    return next(
        (pair for pair in map(_base_and_unit, get_args(hint)) if pair is not None),
        None,
    )
def _extract_tuple_annotations(
    hint: Any,
) -> list[tuple[tuple[type, u.Unit] | None, Any]] | None:
    """
    Break a ``tuple[...]`` type hint into per-element unit annotations.

    Parameters
    ----------
    hint : Any
        Type hint that may be a tuple whose elements are ``Annotated`` types.

    Returns
    -------
    list[tuple[tuple[type, u.Unit] | None, Any]] | None
        One ``((quantity_type, unit), element_hint)`` pair per element of the tuple
        hint, with None in place of ``(quantity_type, unit)`` for elements that are
        not annotated. None when ``hint`` is not a tuple hint at all.
    """
    if get_origin(hint) is not tuple:
        return None
    return [
        (_extract_annotated_from_hint(element_hint), element_hint)
        for element_hint in get_args(hint)
    ]
def _validate_tuple_return(result, expected_annotations):
"""
Validate tuple return values against their type annotations.
Parameters
----------
result : Any
The actual return value (should be a tuple)
expected_annotations : list[tuple[tuple[type, u.Unit] | None, Any]]
List of ((quantity_type, unit), original_hint) for each tuple element
"""
if not isinstance(result, tuple):
raise TypeError("Expected tuple return value.")
if len(result) != len(expected_annotations):
raise TypeError(
f"Expected tuple with {len(expected_annotations)} elements, "
f"got {len(result)} elements."
)
for i, (value, (annotation, original_hint)) in enumerate(
zip(result, expected_annotations, strict=False)
):
if annotation is not None: # Only check annotated elements
_, expected_unit = annotation
if value is None:
# Check if None is allowed (Optional type)
origin = get_origin(original_hint)
if not (
origin in (Union, getattr(types, "UnionType", ()))
and type(None) in get_args(original_hint)
):
raise TypeError(
f"tuple[{i}] is None but not annotated as Optional."
)
continue
if not isinstance(value, Quantity):
raise TypeError(f"tuple[{i}] must be an astropy Quantity.")
if value.unit != expected_unit:
raise u.UnitConversionError(
f"tuple[{i}] unit {value.unit} != annotated {expected_unit}."
)
def _convert_parameter_units(name: str, value: Any, expected_unit: u.Unit) -> Quantity:
    """
    Convert one argument to the unit its annotation requires.

    Parameters
    ----------
    name : str
        Parameter name, used only in error messages.
    value : Any
        Argument supplied by the caller; must be an astropy Quantity.
    expected_unit : u.Unit
        Unit the argument is converted to.

    Returns
    -------
    Quantity
        ``value`` expressed in ``expected_unit``.

    Raises
    ------
    TypeError
        If ``value`` is not a Quantity.
    UnitConversionError
        If ``value``'s unit cannot be converted to ``expected_unit``.
    """
    if isinstance(value, Quantity):
        # Temperature scales such as deg_C only convert to/from K through the
        # temperature equivalency. A bare u.dB is converted as though it were
        # u.dB(1) by way of the logarithmic equivalency.
        conversion_rules = (
            u.temperature()
            if expected_unit.is_equivalent(u.K, equivalencies=u.temperature())
            else u.logarithmic()
            if value.unit == u.dB
            else []
        )
        try:
            return value.to(expected_unit, equivalencies=conversion_rules)
        except u.UnitConversionError as err:
            raise u.UnitConversionError(
                f"Parameter '{name}' requires unit compatible with {expected_unit}, "
                f"but got {value.unit}. Original error: {err}"
            ) from err
    raise TypeError(
        f"Parameter '{name}' must be provided as an astropy Quantity with unit "
        f"compatible with {expected_unit}, not a raw number."
    )
def _validate_single_return(result: Any, expected_unit: u.Unit, ret_hint: Any) -> None:
    """
    Check a single (non-tuple) return value against its annotated unit.

    Parameters
    ----------
    result : Any
        Value returned by the wrapped function.
    expected_unit : u.Unit
        Unit the return value must compare equal to.
    ret_hint : Any
        Original return type hint; consulted only to decide whether None is allowed.

    Raises
    ------
    TypeError
        If ``result`` is None but ``ret_hint`` is not Optional, or ``result`` is not
        a Quantity.
    UnitConversionError
        If ``result.unit`` is not equal to ``expected_unit`` (no conversion is
        attempted).
    """
    if result is not None:
        if not isinstance(result, Quantity):
            raise TypeError("Return value must be an astropy Quantity.")
        if result.unit != expected_unit:
            raise u.UnitConversionError(
                f"Return unit {result.unit} != annotated {expected_unit}."
            )
        return

    # A None result is only legal for Optional[...] / X | None hints.
    union_origins = (Union, getattr(types, "UnionType", ()))
    none_allowed = get_origin(ret_hint) in union_origins and type(None) in get_args(
        ret_hint
    )
    if not none_allowed:
        raise TypeError("Return value is None but not annotated as Optional.")
def _validate_return_units(result: Any, ret_hint: Any) -> None:
    """
    Check a return value against the units declared in its return annotation.

    Does nothing unless the module flag ``_RETURN_UNITS_CHECK_ENABLED`` is true and a
    return hint exists. Tuple hints are validated element by element. Any other hint
    carrying an ``Annotated`` unit (directly or inside a Union) is validated as a
    single value. Hints without an ``Annotated`` unit are ignored.

    Parameters
    ----------
    result : Any
        Value returned by the wrapped function.
    ret_hint : Any
        Return type hint, or None when the function declares none.
    """
    if ret_hint is None or not _RETURN_UNITS_CHECK_ENABLED:
        return

    if (per_element := _extract_tuple_annotations(ret_hint)) is not None:
        _validate_tuple_return(result, per_element)
    elif (annotated := _extract_annotated_from_hint(ret_hint)) is not None:
        _validate_single_return(result, annotated[1], ret_hint)
def _process_parameter(name: str, value: Any, hint: Any) -> Quantity | Any:
    """
    Convert one bound argument to the unit named by its type hint, if any.

    Parameters
    ----------
    name : str
        Parameter name (forwarded for error messages).
    value : Any
        Argument value.
    hint : Any
        The parameter's type hint, or None if it has none.

    Returns
    -------
    Quantity | Any
        ``value`` converted to the annotated unit. ``value`` itself is returned
        unchanged when the hint carries no ``Annotated`` unit or when ``value`` is
        None.
    """
    annotated_info = _extract_annotated_from_hint(hint)
    # Unannotated parameters, and None passed to optional ones, pass through as-is.
    if annotated_info is None or value is None:
        return value
    _quantity_type, target_unit = annotated_info
    return _convert_parameter_units(name, value, target_unit)
def _wrap_function_with_unit_enforcement(
    func: Callable[..., Any],
) -> Callable[..., Any]:
    """
    Wrap ``func`` so its ``Annotated`` arguments are unit-checked and converted.

    Shared by ``enforce_units`` for plain functions and for dataclass ``__init__``
    methods. The signature and type hints (including ``Annotated`` extras) are read
    once, when the wrapper is built. On every call the wrapper:

    1. binds the arguments and fills in defaults,
    2. passes each argument through ``_process_parameter``,
    3. calls ``func``,
    4. hands the result to ``_validate_return_units``.

    For variadic parameters (``*args: T`` / ``**kwargs: T``) the annotation
    describes each individual argument, so every element is processed on its own
    rather than the whole tuple/dict being treated as a single value.

    Parameters
    ----------
    func : Callable[..., Any]
        Function to wrap.

    Returns
    -------
    Callable[..., Any]
        Wrapper with ``func``'s metadata copied via ``functools.wraps``.
    """
    sig = signature(func)
    hints = get_type_hints(func, include_extras=True)
    star_args = {
        name
        for name, param in sig.parameters.items()
        if param.kind is param.VAR_POSITIONAL
    }
    star_kwargs = {
        name
        for name, param in sig.parameters.items()
        if param.kind is param.VAR_KEYWORD
    }

    def _convert_argument(name, value):
        # Apply the parameter's hint to the value, element-wise for *args/**kwargs.
        hint = hints.get(name)
        if name in star_args:
            return tuple(
                _process_parameter(f"{name}[{index}]", item, hint)
                for index, item in enumerate(value)
            )
        if name in star_kwargs:
            return {
                key: _process_parameter(key, item, hint) for key, item in value.items()
            }
        return _process_parameter(name, value, hint)

    @wraps(func)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        # Values are replaced in place; the set of argument names never changes,
        # so updating the dict while iterating over it is safe.
        for name, value in bound.arguments.items():
            bound.arguments[name] = _convert_argument(name, value)
        result = func(*bound.args, **bound.kwargs)
        _validate_return_units(result, hints.get("return"))
        return result

    return wrapper
def enforce_units(func_or_class: FuncOrClass) -> FuncOrClass:
    """
    Decorator to enforce the units specified in function parameter type annotations.

    This decorator enforces some unit consistency rules for function parameters that
    are annotated with one of the ``Annotated`` types in this module:

    * The argument must be a ``Quantity`` object.
    * The argument must be provided with a compatible unit. For example, a
      ``Frequency`` argument's units can be ``u.Hz``, ``u.MHz``, ``u.GHz``, etc. but
      not ``u.m``, ``u.K``, or any other non-frequency unit.

    In addition to the above, the value of any ``Annotated`` argument is converted
    automatically to the unit specified for that type. For example, an ``Angle``
    argument is converted to ``u.rad``, even if it is provided with a unit of
    ``u.deg``. This lets functions flexibly handle compatible units while keeping
    tedious unit conversion logic out of the function body. Arguments passed as
    ``None`` (for optional parameters) are left untouched.

    When applied to a dataclass, this decorator wraps the class's ``__init__``
    method to enforce units on dataclass field assignments. The class must already
    be a dataclass when this decorator runs, so ``@enforce_units`` has to be placed
    above ``@dataclass``.

    Parameters
    ----------
    func_or_class : callable or class
        The function or dataclass to wrap.

    Returns
    -------
    callable or class
        The wrapped function, or the same dataclass with its ``__init__`` wrapped.

    Raises
    ------
    UnitConversionError
        If any argument has incompatible units (raised when the wrapped callable is
        called).
    TypeError
        If an ``Annotated`` argument is not an Astropy ``Quantity`` object (raised
        when the wrapped callable is called), or if the decorator is applied to a
        class that is not a dataclass.
    """
    if not isinstance(func_or_class, type):
        # Plain function (or other callable): wrap it directly.
        return _wrap_function_with_unit_enforcement(func_or_class)

    if not dataclasses.is_dataclass(func_or_class):
        # Regular class - this is probably a mistake
        raise TypeError(
            f"@enforce_units should not be applied to regular classes. "
            f"Apply it directly to the __init__ method instead:\n\n"
            f"class {func_or_class.__name__}:\n"
            f" @enforce_units\n"
            f" def __init__(self, ...):\n"
            f" ..."
        )

    # Dataclass: enforce units on the generated __init__.
    func_or_class.__init__ = _wrap_function_with_unit_enforcement(
        func_or_class.__init__
    )
    return func_or_class
@enforce_units
def wavelength(frequency: Frequency) -> Wavelength:
    r"""
    Convert frequency to wavelength.

    Computes :math:`\lambda = c / f`, with :math:`c` the speed of light.

    Parameters
    ----------
    frequency : Quantity
        Frequency quantity (e.g., in Hz, MHz or GHz). ``enforce_units`` converts it
        to Hz before the body runs.

    Returns
    -------
    Quantity
        Wavelength in meters

    Raises
    ------
    UnitConversionError
        If the input quantity has incompatible units
    TypeError
        If the input is not an astropy Quantity
    """
    # c / f carries the composite unit m / (Hz s); express it in plain meters, as
    # documented and as the Wavelength return annotation declares.
    return (constants.c / frequency).to(u.m)
@enforce_units
def frequency(wavelength: Wavelength) -> Frequency:
    r"""
    Convert wavelength to frequency.

    Computes :math:`f = c / \lambda`, with :math:`c` the speed of light.

    Parameters
    ----------
    wavelength : Quantity
        Wavelength quantity (e.g., in meters, cm or mm). ``enforce_units`` converts
        it to meters before the body runs.

    Returns
    -------
    Quantity
        Frequency in hertz

    Raises
    ------
    UnitConversionError
        If the input quantity has incompatible units
    TypeError
        If the input is not an astropy Quantity
    """
    # c / lambda simplifies to the unit 1 / s; express it in hertz, as documented
    # and as the Frequency return annotation declares.
    return (constants.c / wavelength).to(u.Hz)
@enforce_units
def return_loss_to_vswr(return_loss: Dimensionless) -> Dimensionless:
    r"""
    Convert a return loss to voltage standing wave ratio (VSWR).

    ``enforce_units`` hands the body the return loss as a linear power ratio
    :math:`\text{RL}_{lin} \ge 1`. Then :math:`|\Gamma| = 1 / \sqrt{\text{RL}_{lin}}`,
    which equals :math:`10^{-\text{RL}/20}` for RL in dB, and
    :math:`\text{VSWR} = (1 + |\Gamma|) / (1 - |\Gamma|)`.

    Parameters
    ----------
    return_loss : Dimensionless
        Return loss. Must be >= 1 if provided as dimensionless or >= 0 if provided in
        decibels (e.g. ``u.dB(1)``). Use np.inf for a perfect match.

    Returns
    -------
    Dimensionless
        VSWR (>= 1). A return loss of exactly 1 (0 dB, total reflection) gives an
        infinite VSWR.

    Raises
    ------
    ValueError
        If return_loss is < 1 (i.e. < 0 dB)
    """
    if np.any(return_loss.value < 1):
        raise ValueError("Return loss must be >= 1.")
    gamma = 1 / np.sqrt(return_loss)
    # |Gamma| == 1 (0 dB) makes the denominator zero; the resulting +inf VSWR is the
    # intended answer, so suppress numpy's divide-by-zero warning for it.
    with np.errstate(divide="ignore"):
        return (1 + gamma) / (1 - gamma)
@enforce_units
def vswr_to_return_loss(vswr: Dimensionless) -> Decibels:
    r"""
    Convert voltage standing wave ratio (VSWR) to return loss in decibels.

    Computes :math:`|\Gamma| = (\text{VSWR} - 1) / (\text{VSWR} + 1)` and returns
    :math:`1 / |\Gamma|^2` expressed in ``u.dB(1)``, i.e.
    :math:`-20 \log_{10} |\Gamma|` dB.

    Parameters
    ----------
    vswr : Quantity
        VSWR value (>= 1). Use 1 for a perfect match (infinite return loss). An
        infinite VSWR (total reflection) gives a return loss of 0 dB.

    Returns
    -------
    Quantity
        Return loss in decibels

    Raises
    ------
    ValueError
        If vswr is less than 1
    """
    if np.any(vswr < 1.0):
        raise ValueError("VSWR must be >= 1.")
    ratio = vswr.to_value(u.dimensionless_unscaled)
    with np.errstate(divide="ignore", invalid="ignore"):
        # (inf - 1) / (inf + 1) is NaN, but an infinite VSWR means |Gamma| = 1.
        # The np.where replaces exactly those NaNs, so the invalid warning is moot.
        gamma = np.where(np.isinf(ratio), 1.0, (ratio - 1) / (ratio + 1))
        # |Gamma| = 0 (VSWR = 1) yields an infinite return loss by design.
        power_ratio = 1 / np.abs(gamma) ** 2
    return u.Quantity(power_ratio, u.dimensionless_unscaled).to(u.dB(1))
def safe_negate(quantity: Quantity) -> Quantity:
    """
    Safely negate a dB or function unit quantity, preserving the unit.

    Astropy does not allow direct negation of function units (like dB), so this
    negates the numeric value and reattaches the original unit.

    Parameters
    ----------
    quantity : Quantity
        Quantity to negate, e.g. one in ``u.dB(1)`` or ``u.dBW``.

    Returns
    -------
    Quantity
        Quantity with the negated value and the same unit as ``quantity``.
    """
    negated_value = -1 * quantity.value
    return negated_value * quantity.unit