# Source code for nkdsl.ir.term

# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Typed declarative symbolic operator-term IR structures."""

from __future__ import annotations

import dataclasses
from typing import Any

from .expressions import AmplitudeExpr
from .expressions import coerce_amplitude_expr

from .predicates import PredicateExpr
from .predicates import coerce_predicate_expr

from .update import UpdateProgram


@dataclasses.dataclass(frozen=True, repr=False)
class KBodyIteratorSpec:
    """
    Static K-body iterator over a pre-computed list of site-index tuples.

    This iterator evaluates a term kernel once per entry in *index_sets*.
    Each entry is a K-tuple of integer site indices that are bound to the
    corresponding element of *labels* inside the evaluation environment.

    For a single-site iterator over all N sites, use
    ``KBodyIteratorSpec(labels=("i",), index_sets=tuple((k,) for k in range(N)))``.

    For a static triplet iterator, provide the explicit list of
    ``(e1, e2, e3)`` triplets.

    For a global (one-branch) term, use
    ``KBodyIteratorSpec(labels=(), index_sets=((),))``.

    Attributes:
        labels: Ordered tuple of K label strings bound per iteration.
        index_sets: M-tuple of K-tuples of int site indices.
    """

    labels: tuple  # tuple[str, ...]
    index_sets: tuple  # tuple[tuple[int, ...], ...]

    @property
    def kind(self) -> str:
        """Iterator kind tag; always ``"kbody"`` for this class."""
        return "kbody"

    @property
    def label_a(self) -> str:
        """First label, or the empty string when there are no labels."""
        return self.labels[0] if self.labels else ""

    @property
    def label_b(self) -> str | None:
        """Second label, or ``None`` when fewer than two labels exist."""
        return self.labels[1] if len(self.labels) > 1 else None

    def _format_iterate_line(self, max_shown: int = 3) -> str:
        """Returns a compact readable description of the iteration domain."""
        if not self.labels:
            return "globally"

        # A negative limit would slice from the end of ``index_sets`` and
        # produce a misleading "+N more" count, so treat it as "show none".
        max_shown = max(0, max_shown)

        if len(self.labels) == 1:
            # Keep the trailing comma so a 1-tuple reads like Python syntax.
            labels_str = f"({self.labels[0]},)"
        else:
            labels_str = f"({', '.join(self.labels)})"

        parts = [str(t) for t in self.index_sets[:max_shown]]
        hidden = len(self.index_sets) - len(parts)
        if hidden > 0:
            # Appending the ellipsis as its own item avoids a dangling
            # ", " separator when no rows are rendered (max_shown == 0).
            parts.append(f"... +{hidden} more")
        return f"for {labels_str} in [{', '.join(parts)}]"

    def format_iterate_line(self, max_shown: int = 3) -> str:
        """
        Returns a compact readable description of the iteration domain.

        Args:
            max_shown: Maximum number of index rows to render explicitly.
                Values below zero are treated as zero.

        Returns:
            str: Human-readable iterator summary. ``"globally"`` when the
            iterator has no labels.
        """
        return self._format_iterate_line(max_shown=max_shown)

    def __repr__(self) -> str:
        # Only the number of index sets is shown; the full list can be large.
        return (
            f"KBodyIteratorSpec(labels={self.labels!r}, "
            f"n_index_sets={len(self.index_sets)})"
        )
def _always_true_predicate() -> PredicateExpr:
    """Builds the default emission predicate, ``PredicateExpr.constant(True)``."""
    return PredicateExpr.constant(True)


@dataclasses.dataclass(frozen=True, repr=False)
class EmissionSpec:
    """
    A single output branch of a term: connected state plus matrix element.

    One iterator evaluation may yield several branches; each branch is
    described by one ``EmissionSpec`` in the parent term's ``emissions``
    tuple. A plaquette term, for instance, can emit both the ``+`` and the
    ``-`` connected state from the same site-tuple instead of being split
    into two separate terms.

    Attributes:
        update_program: Site-update program mapping ``x -> x'``.
        amplitude: Matrix-element expression evaluated in the source
            environment.
        branch_tag: Optional diagnostic tag for this emission slot.
        predicate: Optional emission-local branch predicate. When omitted, a
            fresh ``PredicateExpr.constant(True)`` is created per instance.
    """

    update_program: UpdateProgram
    amplitude: AmplitudeExpr
    branch_tag: Any = None
    predicate: PredicateExpr = dataclasses.field(default_factory=_always_true_predicate)

    @property
    def matrix_element(self) -> AmplitudeExpr:
        """Preferred alias for ``amplitude``."""
        return self.amplitude
@dataclasses.dataclass(frozen=True, repr=False)
class SymbolicIRTerm:
    """
    A single primitive declarative symbolic operator term.

    Attributes:
        name: Term name; must contain at least one non-whitespace character.
        iterator: Iterator descriptor (``KBodyIteratorSpec``).
        predicate: Branch-selection predicate.
        update_program: Matrix-element update program.
        amplitude: Matrix-element expression.
        branch_tag: Optional branch tag for diagnostics.
        metadata: Optional stable term metadata tuple.
        max_conn_size_hint: Optional static upper-bound hint on the number of
            connected states this term produces per input configuration.
            Must be strictly positive when provided.
        emissions: Optional multi-emission tuple that, when present,
            supersedes *update_program* and *amplitude*. Each entry is an
            :class:`EmissionSpec` representing one output branch per
            iterator evaluation.
    """

    name: str
    iterator: Any  # KBodyIteratorSpec
    predicate: PredicateExpr
    update_program: UpdateProgram
    amplitude: AmplitudeExpr
    branch_tag: Any = None
    metadata: tuple = dataclasses.field(default_factory=tuple)
    max_conn_size_hint: int | None = None
    emissions: tuple | None = None  # tuple[EmissionSpec, ...] | None

    def __post_init__(self) -> None:
        """
        Validates the term name and the connected-size hint.

        Raises:
            ValueError: If the name is blank, or the hint is given but is
                not strictly positive.
        """
        if not self.name.strip():
            raise ValueError("Symbolic IR term name must be a non-empty string.")
        hint = self.max_conn_size_hint
        if hint is not None and hint <= 0:
            raise ValueError(
                f"max_conn_size_hint must be a positive integer when provided; "
                f"received {hint!r}."
            )

    @property
    def effective_emissions(self) -> tuple:
        """
        Returns the emissions that are actually in effect for this term.

        An explicit *emissions* tuple is returned unchanged. Without one, the
        legacy *update_program* / *amplitude* / *branch_tag* fields are
        wrapped into a one-element tuple holding a single ``EmissionSpec``.
        """
        explicit = self.emissions
        if explicit is None:
            legacy = EmissionSpec(
                update_program=self.update_program,
                amplitude=self.amplitude,
                branch_tag=self.branch_tag,
            )
            return (legacy,)
        return explicit

    @classmethod
    def create(
        cls,
        *,
        name: str,
        iterator: Any,
        predicate: Any,
        update_program: UpdateProgram,
        amplitude: Any,
        branch_tag: Any = None,
        metadata: dict[str, Any] | None = None,
        max_conn_size_hint: int | None = None,
        fanout_hint: int | None = None,
        emissions: tuple | None = None,
    ) -> "SymbolicIRTerm":
        """
        Builds a term from user-friendly values.

        *name* is passed through ``str``, *predicate* and *amplitude* through
        the module's coercion helpers, and *metadata* is stored as a tuple of
        its items in sorted order. *fanout_hint* is only consulted when
        *max_conn_size_hint* is ``None``.
        """
        meta_tuple: tuple = () if metadata is None else tuple(sorted(metadata.items()))
        resolved_hint = fanout_hint if max_conn_size_hint is None else max_conn_size_hint
        return cls(
            name=str(name),
            iterator=iterator,
            predicate=coerce_predicate_expr(predicate),
            update_program=update_program,
            amplitude=coerce_amplitude_expr(amplitude),
            branch_tag=branch_tag,
            metadata=meta_tuple,
            max_conn_size_hint=resolved_hint,
            emissions=emissions,
        )

    @property
    def fanout_hint(self) -> int | None:
        """Backward-compatible alias for ``max_conn_size_hint``."""
        return self.max_conn_size_hint

    def metadata_dict(self) -> dict[str, Any]:
        """Returns the metadata tuple converted to a new dictionary."""
        return dict(self.metadata)

    @property
    def free_symbols(self) -> frozenset:
        """
        Returns the set of free (non-iterator-bound) symbol names in this term.

        Free symbols are those not bound by any site/emit iterator label,
        i.e. named parameters such as ``symbol("kappa")`` that remain
        unresolved. Symbols declared with ``default=...`` are considered
        resolved and are omitted.
        """
        from .expressions import _collect_free_symbols
        from .predicates import _collect_free_symbols_pred
        from .update import _collect_free_symbols_from_ops

        # The collectors accumulate into the set handed to them.
        found: set = set()
        _collect_free_symbols_pred(self.predicate, found)
        for emission in self.effective_emissions:
            _collect_free_symbols_pred(emission.predicate, found)
            _collect_free_symbols(emission.amplitude, found)
            _collect_free_symbols_from_ops(emission.update_program.ops, found)
        return frozenset(found)

    def _to_ir_lines(self, idx: "int | None" = None, indent: str = "") -> "list[str]":
        """Formats this term as indented IR lines suitable for embedding in an operator dump."""
        iterator = self.iterator
        index_part = "" if idx is None else f"#{idx} "

        # Iterators are duck-typed here: anything lacking the expected
        # attributes falls back to a placeholder or its repr.
        kind = getattr(iterator, "kind", "unknown")
        n_iter = len(iterator.index_sets) if hasattr(iterator, "index_sets") else "?"
        hint = self.max_conn_size_hint
        max_conn = "?" if hint is None else hint
        if hasattr(iterator, "format_iterate_line"):
            domain = iterator.format_iterate_line()
        else:
            domain = repr(iterator)

        # NOTE(review): whitespace inside these literals is reproduced exactly
        # as it appears in the reviewed source - confirm the intended widths.
        inner = indent + " "
        lines = [
            f'{indent}term {index_part}"{self.name}" [{kind}, n_iter={n_iter}, max_conn_size={max_conn}] {{',
            f"{inner}iterate: {domain}",
            f"{inner}where: {self.predicate}",
        ]
        for slot, emission in enumerate(self.effective_emissions):
            tag = "" if emission.branch_tag is None else f" [tag={emission.branch_tag!r}]"
            lines.append(f"{inner}emit #{slot}{tag}:")
            pred = emission.predicate
            # A constant-true emission predicate adds no information; skip it.
            always_true = pred.op == "const" and bool(pred.args[0])
            if not always_true:
                lines.append(f"{inner} where: {pred}")
            lines.append(f"{inner} update: {emission.update_program}")
            lines.append(f"{inner} amplitude: {emission.amplitude}")
        lines.append(f"{indent}}}")
        return lines

    def to_ir_lines(self, idx: "int | None" = None, indent: str = "") -> "list[str]":
        """
        Formats this term as indented IR lines for operator dumps.

        Args:
            idx: Optional term index to include in header formatting.
            indent: Prefix indentation applied to all returned lines.

        Returns:
            list[str]: Rendered IR lines for this term.
        """
        return self._to_ir_lines(idx=idx, indent=indent)

    def __str__(self) -> str:
        return "\n".join(self._to_ir_lines(idx=None, indent=""))

    def __repr__(self) -> str:
        fields = (
            f"name={self.name!r}, "
            f"iterator={self.iterator!r}, "
            f"max_conn_size_hint={self.max_conn_size_hint!r}"
        )
        return f"SymbolicIRTerm({fields})"
def _scale_ir_term(term: SymbolicIRTerm, scale_expr: AmplitudeExpr) -> SymbolicIRTerm:
    """
    Returns a copy of *term* whose amplitudes are multiplied by *scale_expr*.

    The top-level ``amplitude`` is always scaled. When ``term.emissions`` is
    not ``None`` every emission's amplitude is scaled as well; otherwise the
    copy keeps ``emissions=None`` (the single-emission path).

    Args:
        term: Term to scale; it is not modified.
        scale_expr: Left-hand factor passed to ``AmplitudeExpr.mul``.

    Returns:
        SymbolicIRTerm: New term built with ``dataclasses.replace``.
    """
    scaled_amplitude = AmplitudeExpr.mul(scale_expr, term.amplitude)

    scaled_emissions: "tuple | None"
    if term.emissions is None:
        scaled_emissions = None
    else:
        scaled_emissions = tuple(
            dataclasses.replace(
                emission,
                amplitude=AmplitudeExpr.mul(scale_expr, emission.amplitude),
            )
            for emission in term.emissions
        )

    return dataclasses.replace(
        term,
        amplitude=scaled_amplitude,
        emissions=scaled_emissions,
    )


__all__ = [
    "EmissionSpec",
    "KBodyIteratorSpec",
    "SymbolicIRTerm",
    "_scale_ir_term",
]