# Source code for nkdsl.core.sum

# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Additive composition of symbolic operators."""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

import numpy as np

from netket.hilbert import DiscreteHilbert

from nkdsl.core.base import AbstractSymbolicOperator
from nkdsl.ir.program import SymbolicOperatorIR


def _term_dtype_name(term: AbstractSymbolicOperator) -> str:
    """
    Returns the normalized NumPy dtype name of a single term.

    The public ``dtype_str`` attribute is tried first, then ``dtype``, and
    finally the legacy private ``_dtype_val`` attribute (which is ignored
    when it is ``None``). Whatever is found is normalized through
    ``np.dtype(...).name``.

    Args:
        term: Symbolic operator term.

    Returns:
        str: Normalized dtype name for the term.

    Raises:
        AttributeError: If none of the dtype-like attributes is available.
    """
    # Public attributes take precedence, in this order.
    for attr_name in ("dtype_str", "dtype"):
        if hasattr(term, attr_name):
            return np.dtype(getattr(term, attr_name)).name

    # Compatibility fallback for older term implementations.
    legacy_dtype = getattr(term, "_dtype_val", None)
    if legacy_dtype is None:
        raise AttributeError(f"Cannot resolve dtype for term of type {type(term).__name__!r}.")
    return np.dtype(legacy_dtype).name


def _term_display_name(term: AbstractSymbolicOperator) -> str:
    """
    Returns a human-readable name for one term, falling back gracefully.

    Lookup order is ``name``, then ``operator_name``, then the legacy private
    ``_name_val`` attribute (ignored when ``None``), and finally the term's
    class name. Found values are converted with ``str``.

    Args:
        term: Symbolic operator term.

    Returns:
        str: Best-effort human-readable term name.
    """
    for attr_name in ("name", "operator_name"):
        if hasattr(term, attr_name):
            return str(getattr(term, attr_name))

    legacy_name = getattr(term, "_name_val", None)
    # Last resort: the concrete class name always exists.
    return type(term).__name__ if legacy_name is None else str(legacy_name)


def _flatten_terms(
    terms: Sequence[AbstractSymbolicOperator],
) -> tuple[AbstractSymbolicOperator, ...]:
    """
    Expands nested ``SymbolicOperatorSum`` instances in place, keeping order.

    Every ``SymbolicOperatorSum`` in ``terms`` is replaced by its own
    ``terms``; any other element is kept as-is. A single level of expansion
    suffices because a nested sum's terms were already flattened by its own
    constructor.

    Args:
        terms: Sequence of symbolic operator terms, possibly containing sums.

    Returns:
        tuple: Flat tuple of terms in their original order.
    """
    return tuple(
        leaf
        for item in terms
        for leaf in (item.terms if isinstance(item, SymbolicOperatorSum) else (item,))
    )


def _resolve_dtype(terms: Sequence[AbstractSymbolicOperator], explicit: str | None) -> str:
    """
    Chooses the dtype name shared by all terms of an additive composition.

    Precedence:
      1. If ``explicit`` is given, its normalized name is returned.
      2. With no terms, ``"complex64"`` is returned.
      3. Otherwise the term dtypes are promoted pairwise, left to right, with
         ``np.result_type``. The fold order is kept deliberately, because
         NumPy promotion is not guaranteed to be associative.

    Args:
        terms: Terms of the composition.
        explicit: Optional user-supplied dtype override.

    Returns:
        str: Normalized dtype name.
    """
    if explicit is not None:
        return np.dtype(explicit).name
    if not terms:
        return np.dtype("complex64").name

    common = None
    for term in terms:
        current = np.dtype(_term_dtype_name(term))
        common = current if common is None else np.result_type(common, current)
    return np.dtype(common).name


class SymbolicOperatorSum(AbstractSymbolicOperator):
    """
    Additive composition of multiple symbolic operators sharing one Hilbert space.

    ``SymbolicOperatorSum`` is the canonical Hamiltonian-style container for
    DSL-defined operators. It preserves term ordering, flattens nested sums,
    and aggregates max-connection-size bounds across all contained terms.

    Args:
        hilbert: Shared Hilbert space.
        terms: Sequence of symbolic operator terms. Nested
            ``SymbolicOperatorSum`` instances are expanded in place.
        name: Optional user-facing operator name. A missing or blank name
            falls back to ``"symbolic_sum"``.
        dtype_str: Optional explicit dtype override. When omitted, the dtype
            is the NumPy promotion of all term dtypes.
        is_hermitian: Optional Hermiticity override (defaults to ``True`` iff
            all contained terms are Hermitian).
        metadata: Optional metadata dictionary.

    Raises:
        ValueError: If no terms remain after flattening, or if any term's
            Hilbert space differs from ``hilbert``.
    """

    __slots__ = ("_terms",)

    def __init__(
        self,
        hilbert: DiscreteHilbert,
        terms: Sequence[AbstractSymbolicOperator],
        *,
        name: str | None = None,
        dtype_str: str | None = None,
        is_hermitian: bool | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        flattened = _flatten_terms(terms)
        if not flattened:
            raise ValueError("SymbolicOperatorSum requires at least one term.")
        for t in flattened:
            if t.hilbert != hilbert:
                raise ValueError(
                    f"All terms in SymbolicOperatorSum must share the same "
                    f"Hilbert space. Term {_term_display_name(t)!r} has a different hilbert."
                )

        # Missing, empty, or whitespace-only names fall back to the default.
        stripped_name = str(name).strip() if name else ""
        resolved_name = stripped_name or "symbolic_sum"
        resolved_dtype = _resolve_dtype(flattened, dtype_str)
        # An explicit override wins; otherwise the sum is Hermitian only when
        # every term is.
        resolved_hermitian = (
            bool(is_hermitian)
            if is_hermitian is not None
            else all(t.is_hermitian for t in flattened)
        )
        super().__init__(
            hilbert,
            name=resolved_name,
            dtype_str=resolved_dtype,
            is_hermitian=resolved_hermitian,
            metadata=metadata,
        )
        self._terms: tuple[AbstractSymbolicOperator, ...] = flattened

    @property
    def terms(self) -> tuple[AbstractSymbolicOperator, ...]:
        """Returns contained additive terms in declaration order."""
        return self._terms

    @property
    def free_symbols(self) -> frozenset:
        """
        Returns the union of free symbol names across all contained terms.

        Terms without a ``free_symbols`` attribute contribute nothing.
        """
        result: set = set()
        for t in self._terms:
            if hasattr(t, "free_symbols"):
                result |= t.free_symbols
        return frozenset(result)

    def __len__(self) -> int:
        """Returns the number of (flattened) terms."""
        return len(self._terms)

    def __iter__(self):
        """Iterates over the contained terms in declaration order."""
        return iter(self._terms)

    def _apply_scalar(self, scalar: "int | float | complex") -> "SymbolicOperatorSum":
        """
        Returns a new sum in which every term is scaled by ``scalar``.

        The result is marked Hermitian only if this sum is Hermitian and
        ``scalar`` is not a complex number (Python ``complex`` or any NumPy
        complex scalar). Its dtype is promoted to accommodate ``scalar``.

        Args:
            scalar: Multiplicative factor.

        Returns:
            SymbolicOperatorSum: The scaled sum.
        """
        # Imported lazily (presumably to avoid a circular import with
        # nkdsl.core.operator).
        from nkdsl.core.operator import (
            _promote_dtype_for_scalar,
            _scalar_str,
        )

        new_terms = tuple(t.apply_scalar(scalar) for t in self._terms)
        # ``np.complex64`` is not a subclass of Python ``complex``, so NumPy
        # complex scalars must be checked explicitly. Otherwise scaling by
        # np.complex64(1j) would wrongly keep the Hermitian flag.
        is_complex_scalar = isinstance(scalar, (complex, np.complexfloating))
        is_hermitian = self.is_hermitian and not is_complex_scalar
        new_dtype = _promote_dtype_for_scalar(self.dtype_str, scalar)
        return SymbolicOperatorSum(
            self.hilbert,
            new_terms,
            name=f"({_scalar_str(scalar)} * {self.name})",
            dtype_str=new_dtype,
            is_hermitian=is_hermitian,
        )

    def to_ir(self) -> SymbolicOperatorIR:
        """
        Builds one aggregate IR from all contained terms.

        All child IRs must be in ``symbolic`` mode. The child terms are
        concatenated in order. The child IR fingerprints are stored under the
        ``"child_ir_fingerprints"`` metadata key of the aggregate.

        Returns:
            Aggregate symbolic operator IR.

        Raises:
            ValueError: If any child IR is not in ``symbolic`` mode.
        """
        child_irs = tuple(t.to_ir() for t in self._terms)
        modes = {ir.mode for ir in child_irs}
        if modes != {"symbolic"}:
            raise ValueError(
                f"Cannot aggregate term IRs with non-symbolic or mixed modes: {modes!r}. "
                "All terms must be in 'symbolic' mode."
            )
        combined_terms = tuple(term for child_ir in child_irs for term in child_ir.terms)
        fingerprints = tuple(ir.static_fingerprint() for ir in child_irs)
        # Copy so this operator's own metadata mapping is not mutated.
        meta = dict(self.metadata)
        meta["child_ir_fingerprints"] = fingerprints
        return SymbolicOperatorIR.from_terms(
            operator_name=self.name,
            hilbert_size=int(self.hilbert.size),
            dtype_str=self.dtype_str,
            is_hermitian=self.is_hermitian,
            terms=combined_terms,
            metadata=meta,
        )

    def estimate_max_conn_size(self) -> int:
        """
        Returns the aggregate static max-connection bound across all terms.

        Terms lacking ``estimate_max_conn_size`` contribute ``hilbert.size``
        as a heuristic fallback. The total is clamped to at least 1.

        Returns:
            Sum of per-term max-connection upper bounds (minimum 1).
        """
        total = 0
        for t in self._terms:
            if hasattr(t, "estimate_max_conn_size"):
                total += int(t.estimate_max_conn_size())
            else:
                # Fallback: assume hilbert.size per unknown term.
                total += int(self.hilbert.size)
        return max(1, total)

    def __repr__(self) -> str:
        return (
            f"SymbolicOperatorSum("
            f"name={self.name!r}, "
            f"term_count={len(self._terms)}, "
            f"dtype={self.dtype_str!r}, "
            f"is_hermitian={self.is_hermitian})"
        )


# Public API of this module: only the sum container is exported; the
# underscore-prefixed helpers above are internal.
__all__ = ["SymbolicOperatorSum"]