Source code for nkdsl.ir.expressions

# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Typed amplitude-expression IR nodes for symbolic operators.

This module defines a compact, hashable expression tree used to represent
matrix-element amplitudes in declarative symbolic operators.
"""

from __future__ import annotations

import dataclasses
import numbers
from typing import Any

import numpy as np

_AMPLITUDE_OPS: frozenset[str] = frozenset(
    {
        "const",
        "symbol",
        "neg",
        "sqrt",
        "conj",
        "add",
        "sub",
        "mul",
        "div",
        "pow",  # base^exp, element-wise power
        "abs_",  # |operand|, absolute value
        "static_index",  # x[flat_index], reads source configuration
        "static_emitted_index",  # x'[flat_index], reads emitted/connected configuration
        "wrap_mod",  # wrap operand using Hilbert local_states modulo semantics
    }
)

_UNSET_SYMBOL_DEFAULT = object()
_SYMBOL_DECLARATION_KEYS: frozenset[str] = frozenset({"default", "doc", "dtype"})


def _freeze(v: Any) -> Any:
    """Recursively convert lists to tuples for hashability."""
    if isinstance(v, list):
        return tuple(_freeze(i) for i in v)
    return v


def _normalize_symbol_name(name: str) -> str:
    """Normalizes one symbol name and validates non-empty content."""
    normalized = str(name).strip()
    if not normalized:
        raise ValueError("Amplitude symbols must be non-empty strings.")
    return normalized


def _normalize_symbol_dtype(dtype: Any | None) -> str | None:
    """Normalizes one optional symbol dtype declaration.

    Args:
        dtype: Optional NumPy-compatible dtype specifier.

    Returns:
        Canonical dtype name, or ``None`` when no dtype is declared.
    """
    if dtype is None:
        return None
    try:
        return np.dtype(dtype).name
    except TypeError as exc:
        raise ValueError(f"Unsupported symbol dtype declaration: {dtype!r}.") from exc


def parse_symbol_declaration_args(args: tuple[Any, ...]) -> tuple[str, dict[str, Any]]:
    """Parses a symbol-expression payload into name + declaration map.

    Supported payload forms:
      1. ``(name,)``
      2. ``(name, declaration_tuple)``

    ``declaration_tuple`` must be a stable tuple of ``(key, value)`` pairs.

    Args:
        args: ``AmplitudeExpr.args`` payload for a ``symbol`` node.

    Returns:
        Tuple ``(name, declaration_map)``.
    """
    if not args:
        raise ValueError("Symbol expression payload must contain at least one argument.")

    name = _normalize_symbol_name(str(args[0]))
    if len(args) == 1:
        return name, {}

    if len(args) != 2:
        raise ValueError(
            "Symbol expression payload must have either one item (name) or two "
            "items (name + declaration tuple)."
        )

    raw_declaration = args[1]
    if not isinstance(raw_declaration, tuple):
        raise TypeError("Symbol declaration payload must be a tuple of (key, value) pairs.")

    declaration: dict[str, Any] = {}
    for entry in raw_declaration:
        if not isinstance(entry, tuple) or len(entry) != 2:
            raise TypeError("Each symbol declaration entry must be a two-item tuple.")
        key = str(entry[0]).strip()
        if not key:
            raise ValueError("Symbol declaration keys must be non-empty strings.")
        if key not in _SYMBOL_DECLARATION_KEYS:
            raise ValueError(
                f"Unsupported symbol declaration key {key!r}. "
                f"Allowed: {sorted(_SYMBOL_DECLARATION_KEYS)!r}."
            )
        if key in declaration:
            raise ValueError(f"Duplicate symbol declaration key {key!r}.")
        value = entry[1]
        if key == "doc":
            value = str(value).strip()
            if not value:
                continue
        elif key == "dtype":
            value = _normalize_symbol_dtype(value)
            if value is None:
                continue
        declaration[key] = value

    return name, declaration


[docs] @dataclasses.dataclass(frozen=True, repr=False) class AmplitudeExpr: """ Typed expression node for operator matrix elements. Attributes: op: Expression operation name. args: Ordered operation arguments (frozen tuple). """ op: str args: tuple = dataclasses.field(default_factory=tuple) def __post_init__(self) -> None: if self.op not in _AMPLITUDE_OPS: raise ValueError( f"Unsupported amplitude-expression op: {self.op!r}. " f"Allowed: {sorted(_AMPLITUDE_OPS)}." ) @classmethod def constant(cls, value: complex | float | int) -> "AmplitudeExpr": """Builds a constant-value expression node.""" if isinstance(value, bool): raise TypeError( "Boolean values are not valid amplitude constants. " "Use numeric literals instead." ) return cls(op="const", args=(_freeze(value),)) @classmethod def symbol( cls, name: str, *, default: Any = _UNSET_SYMBOL_DEFAULT, doc: str = "", dtype: Any | None = None, ) -> "AmplitudeExpr": """Builds a symbol-reference expression node. Args: name: Symbol name. default: Optional default value used when the symbol is not supplied in the runtime evaluation environment. doc: Optional descriptive note for tooling and readability. dtype: Optional declared dtype for this symbol. If omitted and ``default`` is provided, dtype is inferred from ``default``. Raises: TypeError: If a provided ``default`` cannot be converted to the declared ``dtype``. """ normalized_name = _normalize_symbol_name(name) normalized_dtype = _normalize_symbol_dtype(dtype) declaration: list[tuple[str, Any]] = [] if default is not _UNSET_SYMBOL_DEFAULT: default_value = _freeze(default) inferred_dtype = np.asarray(default).dtype.name if normalized_dtype is None: normalized_dtype = inferred_dtype else: try: np.asarray(default, dtype=np.dtype(normalized_dtype)) except (TypeError, ValueError) as exc: raise TypeError( f"Default value for symbol {normalized_name!r} cannot be converted " f"to declared dtype {normalized_dtype!r}." ) from exc declaration.append(("default", default_value)) normalized_doc = str(doc).strip() if normalized_doc: declaration.append(("doc", normalized_doc)) if normalized_dtype is not None: declaration.append(("dtype", normalized_dtype)) if declaration: return cls(op="symbol", args=(normalized_name, tuple(declaration))) return cls(op="symbol", args=(normalized_name,)) def neg(self) -> "AmplitudeExpr": """Builds a unary negation expression node.""" return AmplitudeExpr(op="neg", args=(coerce_amplitude_expr(self),)) def sqrt(self) -> "AmplitudeExpr": """Builds a square-root expression node.""" return AmplitudeExpr(op="sqrt", args=(coerce_amplitude_expr(self),)) def conj(self) -> "AmplitudeExpr": """Builds a complex-conjugate expression node.""" return AmplitudeExpr(op="conj", args=(coerce_amplitude_expr(self),)) def add(self, other: Any) -> "AmplitudeExpr": """Builds an addition expression node.""" return AmplitudeExpr( op="add", args=(coerce_amplitude_expr(self), coerce_amplitude_expr(other)), ) def sub(self, other: Any) -> "AmplitudeExpr": """Builds a subtraction expression node.""" return AmplitudeExpr( op="sub", args=(coerce_amplitude_expr(self), coerce_amplitude_expr(other)), ) def mul(self, other: Any) -> "AmplitudeExpr": """Builds a multiplication expression node.""" return AmplitudeExpr( op="mul", args=(coerce_amplitude_expr(self), coerce_amplitude_expr(other)), ) def div(self, other: Any) -> "AmplitudeExpr": """Builds a division expression node.""" return AmplitudeExpr( op="div", args=(coerce_amplitude_expr(self), coerce_amplitude_expr(other)), ) def pow(self, exponent: Any) -> "AmplitudeExpr": """Builds a power expression node (base ** exponent).""" return AmplitudeExpr( op="pow", args=(coerce_amplitude_expr(self), coerce_amplitude_expr(exponent)), ) def abs_(self) -> "AmplitudeExpr": """Builds an absolute-value expression node (|operand|).""" return AmplitudeExpr(op="abs_", args=(coerce_amplitude_expr(self),)) @classmethod def static_index(cls, flat_index: int) -> "AmplitudeExpr": """ Builds a static-index read node (reads x[flat_index] at eval time). Unlike :meth:`symbol`, which is bound to a site-iterator label, a static-index node reads the full input configuration ``x`` at a compile-time-constant integer offset. This is the primary mechanism for accessing *structured* Hilbert-space layouts (e.g. U(1)^G gauge theories where ``x[g * n_edges + e]`` holds the charge for gauge copy ``g`` at edge ``e``). The flat index is stored as a plain ``int`` in ``args[0]``. The JAX lowerer resolves it by reading ``env["__x__"][flat_index]``. Args: flat_index: Non-negative integer flat index into the configuration array ``x``. """ idx = int(flat_index) if idx < 0: raise ValueError(f"static_index requires a non-negative integer; got {flat_index!r}.") return cls(op="static_index", args=(idx,)) @classmethod def static_emitted_index(cls, flat_index: int) -> "AmplitudeExpr": """Builds a static-index read node for the emitted/connected state x'[flat_index].""" idx = int(flat_index) if idx < 0: raise ValueError( f"static_emitted_index requires a non-negative integer; got {flat_index!r}." ) return cls(op="static_emitted_index", args=(idx,)) def wrap_mod(self) -> "AmplitudeExpr": """Builds a Hilbert-aware modulo-wrap node.""" return AmplitudeExpr(op="wrap_mod", args=(coerce_amplitude_expr(self),)) def __add__(self, other: Any) -> "AmplitudeExpr": return self.add(other) def __radd__(self, other: Any) -> "AmplitudeExpr": return AmplitudeExpr.add(other, self) def __sub__(self, other: Any) -> "AmplitudeExpr": return self.sub(other) def __rsub__(self, other: Any) -> "AmplitudeExpr": return AmplitudeExpr.sub(other, self) def __mul__(self, other: Any) -> "AmplitudeExpr": return self.mul(other) def __rmul__(self, other: Any) -> "AmplitudeExpr": return AmplitudeExpr.mul(other, self) def __truediv__(self, other: Any) -> "AmplitudeExpr": return self.div(other) def __rtruediv__(self, other: Any) -> "AmplitudeExpr": return AmplitudeExpr.div(other, self) def __neg__(self) -> "AmplitudeExpr": return self.neg() # # # Comparison operators yield PredicateExpr (imported lazily) def __lt__(self, other: Any) -> Any: from .predicates import PredicateExpr return PredicateExpr.lt(self, other) def __le__(self, other: Any) -> Any: from .predicates import PredicateExpr return PredicateExpr.le(self, other) def __gt__(self, other: Any) -> Any: from .predicates import PredicateExpr return PredicateExpr.gt(self, other) def __ge__(self, other: Any) -> Any: from .predicates import PredicateExpr return PredicateExpr.ge(self, other) def __str__(self) -> str: return _render_amplitude(self) def __repr__(self) -> str: return f"AmplitudeExpr(op={self.op!r}, args={self.args!r})"
def _render_amplitude(expr: "AmplitudeExpr") -> str: """Renders an AmplitudeExpr as a readable infix string.""" op = expr.op args = expr.args if op == "const": v = args[0] if isinstance(v, complex): if v.imag == 0.0: v = v.real else: return repr(v) if isinstance(v, float) and v == int(v) and abs(v) < 1e15: return str(int(v)) return repr(v) if op == "symbol": name, _declaration = parse_symbol_declaration_args(args) parts = name.split(":") if len(parts) == 3: ns, label, field = parts if ns == "site": if field == "value": return f"x[{label}]" if field == "index": return label return f"x[{label}].{field}" if ns == "emit": if field == "value": return f"x'[{label}]" if field == "index": return label return f"x'[{label}].{field}" return f"%{name}" if op == "neg": inner = _render_amplitude(args[0]) return f"-{inner}" if op == "sqrt": return f"sqrt({_render_amplitude(args[0])})" if op == "conj": return f"conj({_render_amplitude(args[0])})" if op == "abs_": return f"|{_render_amplitude(args[0])}|" if op == "wrap_mod": return f"wrap({_render_amplitude(args[0])})" if op == "static_index": return f"x[{args[0]}]" if op == "static_emitted_index": return f"x'[{args[0]}]" if op == "add": return f"({_render_amplitude(args[0])} + {_render_amplitude(args[1])})" if op == "sub": return f"({_render_amplitude(args[0])} - {_render_amplitude(args[1])})" if op == "mul": return f"({_render_amplitude(args[0])} * {_render_amplitude(args[1])})" if op == "div": return f"({_render_amplitude(args[0])} / {_render_amplitude(args[1])})" if op == "pow": return f"({_render_amplitude(args[0])}^{_render_amplitude(args[1])})" # Fallback for unknown ops arg_strs = ", ".join( _render_amplitude(a) if isinstance(a, AmplitudeExpr) else repr(a) for a in args ) return f"{op}({arg_strs})" def _collect_free_symbols(expr: "AmplitudeExpr", result: "set[str]") -> None: """Recursively collects free (non-iterator-bound) symbol names from an AmplitudeExpr.""" if expr.op == "symbol": name, declaration = parse_symbol_declaration_args(expr.args) parts = name.split(":") # Bound symbols follow "namespace:label:field" — site:i:value, emit:i:index, etc. if not (len(parts) == 3 and parts[0] in ("site", "emit")) and "default" not in declaration: result.add(name) return for arg in expr.args: if isinstance(arg, AmplitudeExpr): _collect_free_symbols(arg, result) elif isinstance(arg, tuple): for item in arg: if isinstance(item, AmplitudeExpr): _collect_free_symbols(item, result) def coerce_amplitude_expr(value: Any) -> AmplitudeExpr: """ Coerces user values into typed amplitude-expression nodes. Args: value: Input expression value. Returns: Typed amplitude expression. Raises: TypeError: If ``value`` cannot be converted. """ if isinstance(value, AmplitudeExpr): return value if isinstance(value, (bool, np.bool_)): raise TypeError( "Boolean values are not valid amplitude constants. " "Use numeric literals instead." ) if isinstance(value, numbers.Number): return AmplitudeExpr.constant(value) if isinstance(value, str): return AmplitudeExpr.symbol(value) raise TypeError( f"Cannot coerce {type(value)!r} into an AmplitudeExpr. " "Use numeric constants, symbol strings, or AmplitudeExpr values." ) __all__ = [ "AmplitudeExpr", "coerce_amplitude_expr", "_collect_free_symbols", "_render_amplitude", "parse_symbol_declaration_args", ]