Source code for nkdsl.dsl.operator

# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Fluent operator builder, the primary entry point to the symbolic DSL.

The :class:`SymbolicDiscreteJaxOperator` class is the single entry point for constructing symbolic
quantum operators.

Usage
-----
::

    from nkdsl import SymbolicDiscreteJaxOperator
    from nkdsl.dsl import site, shift, swap, identity

    # diagonal operator
    N_e0 = (
        SymbolicDiscreteJaxOperator(hi, "N_e0", hermitian=True)
        .globally()
        .emit(identity(), matrix_element=my_sq_norm_expr)
        .build()
    )

    # single-site off-diagonal
    h_plus = (
        SymbolicDiscreteJaxOperator(hi, "h+")
        .for_each_site("e")
        .where(site("e") < cutoff)
        .emit(shift("e", +1))
        .build()
    )

    # hopping: compound update + matrix element from both DOFs
    hop = (
        SymbolicDiscreteJaxOperator(hi, "hopping")
        .for_each_pair("i", "j")
        .where(site("i") > 0)
        .emit(
            shift("i", -1).shift("j", +1),
            matrix_element=site("i").value * site("j").value,
        )
        .build()
    )

    # K-body: static triplet iterator
    vol = (
        SymbolicDiscreteJaxOperator(hi, "triplet_volume")
        .for_each(("e1", "e2", "e3"), over=triplet_index_sets)
        .emit(identity(), matrix_element=triple_product_expr)
        .build()
    )

    # multi-emission: two branches per iterator evaluation
    two_branch = (
        SymbolicDiscreteJaxOperator(hi, "two_branch")
        .for_each_site("i")
        .where(site("i").abs() < 2)
        .emit(shift("i", +1), matrix_element=+0.5)
        .emit(shift("i", -1), matrix_element=-0.5)
        .build()
    )

    # compile directly (skip explicit .build())
    compiled = SymbolicDiscreteJaxOperator(hi, "my_op").for_each_site("i").emit(shift("i", +1)).compile()

Iterator methods
----------------
Calling any ``for_each_*`` / ``globally`` method **seals** the current
in-progress term and begins a new one. Any subsequent ``.where`` or
``.emit`` call is always associated with the most recent iterator call.

Multi-emission
--------------
Multiple ``.emit(...)`` calls on the same iterator scope produce multiple
output branches (``EmissionSpec`` entries) from a **single** iterator
evaluation. This avoids the overhead of iterating over sites twice and
keeps the semantic unit cohesive.

Connected-state dedup note
--------------------------
By default, if two terms (or two emissions within one term) produce the same
``x'``, those entries are merged and their matrix elements are summed. Final
zero-amplitude components are dropped before padding.

Set ``deduplicate_connected_components=False`` during ``.compile(...)`` to keep
raw per-branch connectivity output.
"""

from __future__ import annotations

from collections.abc import Sequence
from typing import Any
from typing import TYPE_CHECKING

import numpy as np

from netket.hilbert import DiscreteHilbert

from nkdsl.debug import event as debug_event
from nkdsl.dsl.emissions.defaults import ensure_default_emission_clause_registrations
from nkdsl.dsl.emissions.dispatch import apply_emission_clause
from nkdsl.dsl.emissions.registry import available_emission_clause_names
from nkdsl.dsl.emissions.registry import resolve_emission_clause
from nkdsl.dsl.emissions.types import EmissionClauseSpec
from nkdsl.dsl.iterators.dispatch import apply_iterator_clause
from nkdsl.dsl.iterators.defaults import ensure_default_iterator_clause_registrations
from nkdsl.dsl.iterators.registry import available_iterator_clause_names
from nkdsl.dsl.iterators.registry import resolve_iterator_clause
from nkdsl.dsl.predicates.dispatch import apply_predicate_clause
from nkdsl.dsl.predicates.defaults import ensure_default_predicate_clause_registrations
from nkdsl.dsl.predicates.registry import available_predicate_clause_names
from nkdsl.dsl.predicates.registry import resolve_predicate_clause
from nkdsl.dsl.rewrite import Update
from nkdsl.dsl.rewrite import (
    _IDENTITY as _IDENTITY_UPDATE,
)

from nkdsl.ir.expressions import AmplitudeExpr
from nkdsl.ir.expressions import (
    coerce_amplitude_expr,
)

from nkdsl.ir.predicates import PredicateExpr
from nkdsl.ir.predicates import coerce_predicate_expr

from nkdsl.ir.term import EmissionSpec
from nkdsl.ir.term import KBodyIteratorSpec
from nkdsl.ir.term import SymbolicIRTerm

from nkdsl.ir.update import UpdateProgram

if TYPE_CHECKING:
    from nkdsl import SymbolicOperator

ensure_default_iterator_clause_registrations()
ensure_default_predicate_clause_registrations()
ensure_default_emission_clause_registrations()


def _update_op_uses_shift_mod(op: Any) -> bool:
    if op.kind == "shift_mod_site":
        return True
    if op.kind == "cond_branch":
        then_ops = op.get("then_ops") or ()
        else_ops = op.get("else_ops") or ()
        return any(_update_op_uses_shift_mod(sub) for sub in then_ops) or any(
            _update_op_uses_shift_mod(sub) for sub in else_ops
        )
    return False


def _program_uses_shift_mod(program: UpdateProgram) -> bool:
    return any(_update_op_uses_shift_mod(op) for op in program.ops)


def _amplitude_uses_wrap_mod(expr: AmplitudeExpr) -> bool:
    if expr.op == "wrap_mod":
        return True
    return any(
        isinstance(arg, AmplitudeExpr) and _amplitude_uses_wrap_mod(arg) for arg in expr.args
    )


def _terms_use_shift_mod(terms: tuple[SymbolicIRTerm, ...]) -> bool:
    for term in terms:
        for em in term.effective_emissions:
            if _program_uses_shift_mod(em.update_program):
                return True
            if _amplitude_uses_wrap_mod(em.amplitude):
                return True
    return False


def _iter_amplitude_constants(expr: AmplitudeExpr):
    """Yields all constant payloads occurring in one amplitude expression tree."""
    if expr.op == "const":
        yield expr.args[0]
        return
    for arg in expr.args:
        if isinstance(arg, AmplitudeExpr):
            yield from _iter_amplitude_constants(arg)
        elif isinstance(arg, tuple):
            for item in arg:
                if isinstance(item, AmplitudeExpr):
                    yield from _iter_amplitude_constants(item)


def _resolve_matrix_element_dtype(
    base_dtype_str: str,
    terms: tuple[SymbolicIRTerm, ...],
) -> str:
    """
    Resolves the operator matrix-element dtype from user default + emissions.

    Complex constants in matrix-element expressions automatically promote
    the operator dtype using NumPy result-type promotion rules.
    """
    resolved = np.dtype(base_dtype_str)
    for term in terms:
        for em in term.effective_emissions:
            for const in _iter_amplitude_constants(em.amplitude):
                const_dtype = np.asarray(const).dtype
                if np.issubdtype(const_dtype, np.complexfloating):
                    resolved = np.result_type(resolved, const_dtype)
    return np.dtype(resolved).name


def _infer_shift_mod_spec_from_hilbert(hilbert: DiscreteHilbert) -> dict[str, Any]:
    """
    Infer uniform wrapped-shift semantics from hilbert.local_states.

    Current contract:
      - finite local_states must exist
      - they must be 1D
      - they must be contiguous unit-spaced integers
        e.g. [-m_max, ..., m_max] or [0, 1, 2, 3]

    This exactly matches the current modulo-wrap semantics used by the
    computational operators.
    """
    local_states = getattr(hilbert, "local_states", None)
    if local_states is None:
        raise ValueError(
            "shift_mod requires a discrete Hilbert with finite local_states. "
            "This Hilbert exposes local_states=None."
        )

    states = np.asarray(local_states)
    if states.ndim != 1 or states.size == 0:
        raise ValueError("shift_mod requires hilbert.local_states to be a non-empty 1D sequence.")

    # Require integer-valued local states
    states_i = states.astype(np.int64)
    if not np.array_equal(states, states_i):
        raise ValueError("shift_mod currently requires integer local_states.")

    # Require contiguous unit-spaced ascending values
    state_min = int(states_i[0])
    expected = np.arange(state_min, state_min + len(states_i), dtype=np.int64)
    if not np.array_equal(states_i, expected):
        raise ValueError(
            "shift_mod currently requires contiguous unit-spaced local_states, "
            "for example [-m_max, ..., m_max]. "
            f"Got {states_i.tolist()!r}."
        )

    return {
        "shift_mod_spec": {
            "version": "uniform_integer_wrap_v1",
            "state_min": state_min,
            "mod_span": int(len(states_i)),
            # included so the IR fingerprint/caches depend on the actual local basis
            "local_states": tuple(int(v) for v in states_i.tolist()),
        }
    }


def _coerce_update(u: Any) -> UpdateProgram:
    """
    Converts a fluent ``Update`` or a ready ``UpdateProgram`` into an ``UpdateProgram``.

    Args:
        u: ``Update`` (converted through ``to_program()``) or ``UpdateProgram``
            (returned as is).

    Returns:
        UpdateProgram: The normalised update program.

    Raises:
        TypeError: If *u* is neither an ``Update`` nor an ``UpdateProgram``.
    """
    if isinstance(u, Update):
        program = u.to_program()
    elif isinstance(u, UpdateProgram):
        program = u
    else:
        raise TypeError(f"Expected Update or UpdateProgram; got {type(u).__name__!r}.")
    return program


def _coerce_amplitude(a: Any) -> AmplitudeExpr:
    """
    Turns a user-supplied matrix element into an ``AmplitudeExpr``.

    A callable is first invoked with a fresh ``ExpressionContext`` and its
    result is coerced; any other value is passed to ``coerce_amplitude_expr``
    directly.

    Args:
        a: Callable taking an ``ExpressionContext``, or a value accepted by
            ``coerce_amplitude_expr``.

    Returns:
        AmplitudeExpr: The coerced amplitude expression.
    """
    value = a
    if callable(value):
        # Imported lazily, only when a callable actually needs a context.
        from nkdsl.dsl.context import ExpressionContext

        value = value(ExpressionContext())
    return coerce_amplitude_expr(value)


#
#
# Internal term-in-progress


class _TermInProgress:
    """Mutable accumulator for one in-progress term definition."""

    __slots__ = (
        "_conditional_remaining",
        "_emissions",
        "_iterator",
        "_max_conn_size_hint",
        "_name",
        "_predicate",
    )

    def __init__(self, iterator: Any) -> None:
        self._iterator = iterator
        self._predicate: PredicateExpr = PredicateExpr.constant(True)
        self._emissions: list[EmissionSpec] = []
        self._name: "str | None" = None
        self._max_conn_size_hint: "int | None" = None
        self._conditional_remaining: PredicateExpr | None = None

    @property
    def name(self) -> "str | None":
        """Returns the optional user-assigned term name."""
        return self._name

    def set_name(self, name: str) -> None:
        """
        Sets the user-assigned term name.

        Args:
            name: Non-empty term name.
        """
        self._name = str(name)

    @property
    def max_conn_size_hint(self) -> "int | None":
        """Returns the optional explicit max-connection-size hint."""
        return self._max_conn_size_hint

    def set_max_conn_size_hint(self, hint: int) -> None:
        """
        Sets an explicit max-connection-size hint.

        Args:
            hint: Positive max-connection-size bound.
        """
        self._max_conn_size_hint = int(hint)

    @property
    def predicate(self) -> PredicateExpr:
        """Returns the current composed predicate expression."""
        return self._predicate

    @property
    def predicate_op(self) -> str:
        """Returns the top-level predicate operator name."""
        return self._predicate.op

    def set_predicate(self, pred: Any) -> None:
        """
        Replaces the current term predicate.

        Args:
            pred: Predicate expression or coercible value.
        """
        self._predicate = coerce_predicate_expr(pred)

    @property
    def emission_count(self) -> int:
        """Returns the current emission count for this term."""
        return len(self._emissions)

    @property
    def has_open_conditional_chain(self) -> bool:
        """Returns whether an if/elseif chain is currently open."""
        return self._conditional_remaining is not None

    def close_conditional_chain(self) -> None:
        """Closes any currently open conditional emission chain."""
        self._conditional_remaining = None

    def _require_conditional_chain(self, method: str) -> PredicateExpr:
        if self._conditional_remaining is None:
            raise ValueError(
                f".{method}() must follow .emit_if(...) or .emit_elseif(...) "
                "without intervening term modifiers."
            )
        return self._conditional_remaining

    def add_emission(
        self,
        update: Any,
        matrix_element: Any,
        tag: Any,
        *,
        predicate: Any = True,
        amplitude: Any | None = None,
    ) -> None:
        prog = _coerce_update(update)
        if amplitude is not None:
            matrix_element = amplitude
        amp = _coerce_amplitude(matrix_element)
        pred = coerce_predicate_expr(predicate)
        self._emissions.append(
            EmissionSpec(
                update_program=prog,
                amplitude=amp,
                branch_tag=tag,
                predicate=pred,
            )
        )
        debug_event(
            "registered term emission",
            scope="dsl",
            tag="DSL",
            update_kind_count=len(prog.ops),
            branch_tag=tag,
            emission_count=len(self._emissions),
            predicate_op=pred.op,
        )

    def add_conditional_if(
        self,
        predicate: Any,
        update: Any,
        matrix_element: Any,
        tag: Any,
        *,
        amplitude: Any | None = None,
    ) -> None:
        cond = coerce_predicate_expr(predicate)
        self.close_conditional_chain()
        self.add_emission(
            update=update,
            matrix_element=matrix_element,
            tag=tag,
            predicate=cond,
            amplitude=amplitude,
        )
        self._conditional_remaining = PredicateExpr.not_(cond)

    def add_conditional_elseif(
        self,
        predicate: Any,
        update: Any,
        matrix_element: Any,
        tag: Any,
        *,
        amplitude: Any | None = None,
    ) -> None:
        remaining = self._require_conditional_chain("emit_elseif")
        cond = coerce_predicate_expr(predicate)
        branch_pred = PredicateExpr.and_(remaining, cond)
        self.add_emission(
            update=update,
            matrix_element=matrix_element,
            tag=tag,
            predicate=branch_pred,
            amplitude=amplitude,
        )
        self._conditional_remaining = PredicateExpr.and_(remaining, PredicateExpr.not_(cond))

    def add_conditional_else(
        self,
        update: Any,
        matrix_element: Any,
        tag: Any,
        *,
        amplitude: Any | None = None,
    ) -> None:
        remaining = self._require_conditional_chain("emit_else")
        self.add_emission(
            update=update,
            matrix_element=matrix_element,
            tag=tag,
            predicate=remaining,
            amplitude=amplitude,
        )
        self.close_conditional_chain()

    def to_ir_term(self, auto_name: str) -> SymbolicIRTerm:
        self.close_conditional_chain()
        if not self._emissions:
            name = self.name if self.name is not None else auto_name
            raise ValueError(
                f"Term {name!r} has no emissions. " "Call .emit(...) before .build() / .compile()."
            )
        emissions_tuple = tuple(self._emissions)
        first = emissions_tuple[0]
        name = self.name if self.name is not None else auto_name

        # Auto-infer max_conn_size from iterator size x emission count when not user-set
        max_conn_size_hint = self.max_conn_size_hint
        if max_conn_size_hint is None and isinstance(self._iterator, KBodyIteratorSpec):
            max_conn_size_hint = len(self._iterator.index_sets) * len(emissions_tuple)

        return SymbolicIRTerm.create(
            name=name,
            iterator=self._iterator,
            predicate=self._predicate,
            update_program=first.update_program,
            amplitude=first.amplitude,
            branch_tag=first.branch_tag,
            emissions=emissions_tuple,
            max_conn_size_hint=max_conn_size_hint,
        )


#
#
#   Operator builder


class SymbolicDiscreteJaxOperator:
    """
    Fluent builder for declarative symbolic quantum operators.

    The builder accumulates one or more *terms*.  Each term consists of an
    *iterator* (which sites to visit), an optional *predicate* (which visits
    to activate), and one or more *emissions* (how to rewrite the
    configuration and what matrix element to assign per active visit).

    Calling any iterator method (``for_each_site``, ``for_each_pair``, ...,
    ``globally``) **seals** the previous term (if any) and begins a new one.
    ``.where`` and ``.emit`` always target the current open term.

    Args:
        hilbert: NetKet :class:`~netket.hilbert.DiscreteHilbert` space.
        name: Readable operator name (accessible as ``.name`` on the
            resulting :class:`~nkdsl.core.operator.SymbolicOperator`).
        dtype: Matrix-element dtype string (default ``"float64"``).
        hermitian: Whether to declare the operator Hermitian.
    """

    __slots__ = (
        "_completed_terms",
        "_current",
        "_dtype",
        "_hermitian",
        "_hilbert",
        "_name",
    )

    def __init__(
        self,
        hilbert: DiscreteHilbert,
        name: str = "operator",
        *,
        dtype: str = "float64",
        hermitian: bool = False,
    ) -> None:
        """
        Initialises an empty builder with no sealed terms and no open term.

        Args:
            hilbert: Hilbert space the operator acts on.
            name: Operator name; surrounding whitespace is stripped.
            dtype: Requested matrix-element dtype string.
            hermitian: Whether to declare the operator Hermitian.

        Raises:
            ValueError: If *name* is empty after stripping whitespace.
        """
        name = str(name).strip()
        if not name:
            raise ValueError("Operator name must be a non-empty string.")

        self._hilbert = hilbert
        self._name = name
        self._dtype = str(dtype)
        self._hermitian = bool(hermitian)
        self._completed_terms: list[SymbolicIRTerm] = []
        self._current: _TermInProgress | None = None

        debug_event(
            "created symbolic dsl builder",
            scope="dsl",
            tag="DSL",
            operator_name=self._name,
            dtype=self._dtype,
            hermitian=self._hermitian,
            hilbert_size=int(self._hilbert.size),
        )

    @property
    def hilbert(self) -> DiscreteHilbert:
        """
        Returns the Hilbert space associated with this builder.

        Returns:
            DiscreteHilbert: The builder Hilbert space.
        """
        return self._hilbert

    def open_term(self, iterator: Any) -> "SymbolicDiscreteJaxOperator":
        """
        Opens a new in-progress term with the provided iterator.

        Args:
            iterator: Iterator descriptor for the new term.

        Returns:
            SymbolicDiscreteJaxOperator: This builder for fluent chaining.
        """
        return self._open_term(iterator)

    def append_predicate(
        self,
        predicate: PredicateExpr,
        *,
        method_name: str = "where",
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Appends one predicate to the current term.

        This is a public facade used by external predicate-clause modules to
        avoid accessing private builder methods directly.

        Args:
            predicate: Predicate expression to append.
            method_name: Name of the fluent method responsible for the append.

        Returns:
            SymbolicDiscreteJaxOperator: This builder for fluent chaining.
        """
        return self._append_predicate(predicate, method_name=method_name)

    def append_emission_clause(
        self,
        spec: EmissionClauseSpec,
        *,
        method_name: str,
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Appends emission behavior described by one normalized emission clause spec.

        Args:
            spec: Normalized emission clause action.
            method_name: Name of the fluent method responsible for the append.

        Returns:
            SymbolicDiscreteJaxOperator: This builder for fluent chaining.
        """
        return self._append_emission_clause(spec, method_name=method_name)

    #
    #
    #   Internal

    def _seal_current(self) -> None:
        """Finalises the current in-progress term and appends it."""
        if self._current is None:
            return

        # Default term names are the zero-based position among sealed terms.
        term_name = str(len(self._completed_terms))
        ir_term = self._current.to_ir_term(term_name)
        self._completed_terms.append(ir_term)
        self._current = None
        debug_event(
            "sealed dsl term",
            scope="dsl",
            tag="DSL",
            term_name=ir_term.name,
            max_conn_size_hint=ir_term.max_conn_size_hint,
            emission_count=len(ir_term.effective_emissions),
        )

    def _open_term(self, iterator: Any) -> "SymbolicDiscreteJaxOperator":
        """Seals any open term and starts a new one with *iterator*."""
        self._seal_current()
        self._current = _TermInProgress(iterator)
        debug_event(
            "opened dsl term",
            scope="dsl",
            tag="DSL",
            iterator_kind=getattr(iterator, "kind", type(iterator).__name__),
            labels=getattr(iterator, "labels", None),
        )
        return self

    def _require_open(self, method: str) -> _TermInProgress:
        """
        Returns the open term, or raises if no iterator has been declared yet.

        Args:
            method: Name of the fluent method making the request; only used
                in the error message.

        Raises:
            ValueError: If there is no open term.
        """
        if self._current is None:
            raise ValueError(
                f".{method}() called before any iterator method "
                "(for_each_site / for_each_pair / for_each / globally). "
                "Declare an iterator first."
            )
        return self._current

    def _bind_clause(self, name: str, clause: Any, apply_clause: Any) -> Any:
        """
        Builds a bound, builder-returning callable for one registered clause.

        Args:
            name: Clause name as requested on the builder.
            clause: Registered clause object; only its ``__doc__`` is copied.
            apply_clause: Module-level ``apply_*_clause`` dispatcher to call.

        Returns:
            Callable forwarding ``(*args, **kwargs)`` to *apply_clause*.
        """

        def _bound(*args: Any, **kwargs: Any) -> "SymbolicDiscreteJaxOperator":
            return apply_clause(self, name, *args, **kwargs)

        _bound.__name__ = name
        _bound.__qualname__ = f"{type(self).__name__}.{name}"
        _bound.__doc__ = clause.__doc__
        return _bound

    #
    #
    #   Iterator methods

    def globally(self) -> "SymbolicDiscreteJaxOperator":
        """
        Sets a **global** iterator, one branch per configuration.

        Use this for diagonal operators (area, number, volume, ...) and for
        off-diagonal operators where the target sites are baked into the
        amplitude or update program via
        :func:`~nkdsl.ir.expressions.AmplitudeExpr.static_index`.

        Returns:
            This builder (for chaining).
        """
        return apply_iterator_clause(self, "globally")

    def for_each_site(self, label: str = "i") -> "SymbolicDiscreteJaxOperator":
        """
        Iterates over **all** sites ``0 ... hilbert.size-1``.

        The site index is bound to *label* in the evaluation environment:
        ``site(label).value`` -> ``x[site_index]`` and
        ``site(label).index`` -> the integer site index.

        Args:
            label: Iterator label string (default ``"i"``).

        Returns:
            This builder (for chaining).
        """
        return apply_iterator_clause(self, "for_each_site", label)

    def for_each_pair(
        self,
        label_a: str = "i",
        label_b: str = "j",
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Iterates over all ordered pairs ``(i, j)`` with ``i, j ∈ [0, N)``.

        Includes diagonal pairs ``(i, i)``.  To exclude them add a predicate
        ``.where(site(label_a).index != site(label_b).index)`` or use
        ``.for_each_distinct_pair()``.

        Args:
            label_a: Primary site label.
            label_b: Secondary site label.

        Returns:
            This builder (for chaining).
        """
        return apply_iterator_clause(self, "for_each_pair", label_a, label_b)

    def for_each_distinct_pair(
        self,
        label_a: str = "i",
        label_b: str = "j",
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Iterates over all ordered pairs ``(i, j)`` with ``i, j ∈ [0, N)``.

        Excludes diagonal pairs ``(i, i)``.  To include them, use
        ``.for_each_pair()``.

        Args:
            label_a: Primary site label.
            label_b: Secondary site label.

        Returns:
            This builder (for chaining).
        """
        return apply_iterator_clause(self, "for_each_distinct_pair", label_a, label_b)

    def for_each_triplet(
        self,
        label_a: str,
        label_b: str,
        label_c: str,
        *,
        over: Sequence[tuple[int, int, int]],
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Iterates over a **static list of ordered triplets**.

        Args:
            label_a: First site label.
            label_b: Second site label.
            label_c: Third site label.
            over: Sequence of ``(i, j, k)`` integer index triplets.

        Returns:
            This builder (for chaining).
        """
        return apply_iterator_clause(
            self,
            "for_each_triplet",
            label_a,
            label_b,
            label_c,
            over=over,
        )

    def for_each_plaquette(
        self,
        label_a: str,
        label_b: str,
        label_c: str,
        label_d: str,
        *,
        over: Sequence[tuple[int, int, int, int]],
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Iterates over a **static list of ordered plaquettes** (4-body).

        Args:
            label_a: Site label for the first corner.
            label_b: Site label for the second corner.
            label_c: Site label for the third corner.
            label_d: Site label for the fourth corner.
            over: Sequence of ``(i, j, k, l)`` integer index 4-tuples.

        Returns:
            This builder (for chaining).
        """
        return apply_iterator_clause(
            self,
            "for_each_plaquette",
            label_a,
            label_b,
            label_c,
            label_d,
            over=over,
        )

    def for_each(
        self,
        labels: Sequence[str],
        *,
        over: Sequence[Sequence[int]],
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Iterates over an **arbitrary static list of K-tuples**.

        This is the most general iterator method.  All other ``for_each_*``
        methods are convenience wrappers around this one.

        Args:
            labels: Sequence of K label strings.
            over: Sequence of K-tuples of integer site indices.  Must be
                non-empty; all tuples must have length ``len(labels)``.

        Returns:
            This builder (for chaining).

        Raises:
            ValueError: If *over* is empty or any tuple has the wrong length.

        Example::

            # Graph-neighbourhood iterator from an adjacency list
            edges = [(src, dst) for src, nbrs in adj.items() for dst in nbrs]
            op = (
                SymbolicDiscreteJaxOperator(hi, "nbr_hop")
                .for_each(("src", "dst"), over=edges)
                .where(site("src") > 0)
                .emit(shift("src", -1).shift("dst", +1))
                .build()
            )
        """
        return apply_iterator_clause(self, "for_each", labels, over=over)

    #
    #
    #   Term annotation

    def named(self, name: str) -> "SymbolicDiscreteJaxOperator":
        """
        Assigns a readable name to the current term.

        By default terms are named by their zero-based index (``"0"``,
        ``"1"``, ...).  Call ``.named(...)`` after an iterator method to
        override this with a descriptive label that appears in IR dumps and
        compiler diagnostics.

        Args:
            name: Non-empty string label for this term.

        Returns:
            This builder (for chaining).

        Raises:
            ValueError: If no term is open, or *name* is empty after
                stripping whitespace.
        """
        term = self._require_open("named")
        name = str(name).strip()
        if not name:
            raise ValueError("Term name must be a non-empty string.")
        # Validate first so a rejected call leaves the open term untouched.
        term.close_conditional_chain()
        term.set_name(name)
        return self

    def max_conn_size(self, hint: int) -> "SymbolicDiscreteJaxOperator":
        """
        Sets an explicit static max-connection-size hint for the current term.

        The hint is an upper bound on the total number of connected states
        this term produces per input configuration.  When not set, the DSL
        infers it automatically as ``n_iter x n_emissions``, a correct but
        conservative bound.  Provide a tighter value when the predicate or
        physics guarantees fewer active branches (e.g. a holonomy operator
        with a hard cutoff always emits exactly 1 state).

        The hint is used by the compiler's buffer pre-allocation pass.

        Args:
            hint: Positive integer upper bound.

        Returns:
            This builder (for chaining).

        Raises:
            ValueError: If no term is open, or *hint* is not positive.
        """
        term = self._require_open("max_conn_size")
        hint = int(hint)
        if hint <= 0:
            raise ValueError(f"max_conn_size hint must be a positive integer; got {hint!r}.")
        # Validate first so a rejected call leaves the open term untouched.
        term.close_conditional_chain()
        term.set_max_conn_size_hint(hint)
        return self

    def fanout(self, hint: int) -> "SymbolicDiscreteJaxOperator":
        """Backward-compatible alias for :meth:`max_conn_size`."""
        return self.max_conn_size(hint)

    def _append_predicate(
        self,
        predicate: PredicateExpr,
        *,
        method_name: str = "where",
    ) -> "SymbolicDiscreteJaxOperator":
        """Composes one predicate into the current term with logical AND."""
        term = self._require_open(method_name)
        term.close_conditional_chain()

        existing = term.predicate
        # A truthy constant predicate is simply replaced rather than AND-ed.
        if existing.op == "const" and bool(existing.args[0]):
            term.set_predicate(predicate)
        else:
            term.set_predicate(PredicateExpr.and_(existing, predicate))

        debug_event(
            "updated term predicate",
            scope="dsl",
            tag="DSL",
            predicate_op=term.predicate_op,
            predicate_method=method_name,
        )
        return self

    def _append_emission_clause(
        self,
        spec: EmissionClauseSpec,
        *,
        method_name: str,
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Applies one emission-clause specification to the current term.

        Args:
            spec: Clause spec; ``spec.mode`` selects the term method to call.
            method_name: Name of the fluent method responsible for the append.

        Returns:
            This builder (for chaining).

        Raises:
            ValueError: If no term is open or ``spec.mode`` is unsupported.
        """
        term = self._require_open(method_name)
        mode = spec.mode
        if mode not in ("emit", "emit_if", "emit_elseif", "emit_else"):
            raise ValueError(
                f"Unsupported emission clause mode {spec.mode!r}. "
                "Expected one of: emit, emit_if, emit_elseif, emit_else."
            )

        # Keyword arguments shared by every emission flavour; a missing
        # update means the identity update.
        common = {
            "update": _IDENTITY_UPDATE if spec.update is None else spec.update,
            "matrix_element": spec.matrix_element,
            "tag": spec.tag,
            "amplitude": spec.amplitude,
        }

        if mode == "emit":
            # Only a plain emit closes the conditional chain here.
            term.close_conditional_chain()
            term.add_emission(**common)
        elif mode == "emit_if":
            term.add_conditional_if(predicate=spec.predicate, **common)
        elif mode == "emit_elseif":
            term.add_conditional_elseif(predicate=spec.predicate, **common)
        else:
            term.add_conditional_else(**common)

        debug_event(
            "applied emission clause",
            scope="dsl",
            tag="DSL",
            clause_method=method_name,
            clause_mode=spec.mode,
            emission_count=term.emission_count,
            branch_tag=spec.tag,
            has_open_conditional_chain=term.has_open_conditional_chain,
        )
        return self

    #
    #
    #   Predicate

    def where(self, predicate: Any) -> "SymbolicDiscreteJaxOperator":
        """
        Sets the **branch predicate** for the current term.

        The predicate is evaluated in the iterator environment (``x``, site
        labels).  Only branches where the predicate is ``True`` emit connected
        states, the rest contribute zero matrix elements.

        Multiple ``.where`` calls on the same term compose with logical AND::

            .where(site("i") > 0).where(site("j") < 2)
            # ↑ equivalent to .where((site("i") > 0) & (site("j") < 2))

        Args:
            predicate: :class:`~nkdsl.ir.predicates.PredicateExpr` or any
                value coercible to one (e.g. ``site("i").value > 0``).

        Returns:
            This builder (for chaining).
        """
        return apply_predicate_clause(self, "where", predicate)

    #
    #
    #   Emission

    def emit(
        self,
        update: Any = None,
        *,
        matrix_element: Any = 1.0,
        amplitude: Any | None = None,
        tag: Any = None,
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Appends one **output branch** to the current term.

        Each call to ``.emit(...)`` on the same iterator scope adds one
        :class:`~nkdsl.ir.term.EmissionSpec` to the current term.  Multiple
        emissions produce multiple connected states per iterator evaluation,
        e.g. raise *and* lower from the same site without splitting into two
        separate terms.

        Matrix-element semantics
        ------------------------
        The matrix-element expression is evaluated in the *source*
        configuration environment ``(x, site_labels)``.  There is no access to
        ``x'`` inside matrix-element expressions: ``<x|O|x'>`` is computed
        from ``x``, not ``x'``.

        Args:
            update: Site-rewrite program describing ``x -> x'``.  Accepts
                :class:`~nkdsl.dsl.rewrite.Update`, or
                :class:`~nkdsl.ir.update.UpdateProgram`.  Pass ``None`` or
                :func:`~nkdsl.dsl.rewrite.identity` for diagonal (identity)
                updates.
            matrix_element: Matrix element: numeric constant, symbolic
                :class:`~nkdsl.ir.expressions.AmplitudeExpr`, or a callable
                ``(ExpressionContext) -> AmplitudeExpr``.
            amplitude: Deprecated alias of ``matrix_element``.
            tag: Optional diagnostic label for this emission branch.

        Returns:
            This builder (for chaining).
        """
        spec = EmissionClauseSpec(
            mode="emit",
            update=update,
            matrix_element=matrix_element,
            amplitude=amplitude,
            tag=tag,
        )
        return self.append_emission_clause(spec, method_name="emit")

    def emit_if(
        self,
        predicate: Any,
        update: Any = None,
        *,
        matrix_element: Any = 1.0,
        amplitude: Any | None = None,
        tag: Any = None,
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Appends the ``if`` branch of a conditional emission chain.

        The branch emits only when *predicate* evaluates to true.  Subsequent
        ``.emit_elseif(...)`` and ``.emit_else(...)`` calls can refine the
        same chain.
        """
        return apply_emission_clause(
            self,
            "emit_if",
            predicate,
            update,
            matrix_element=matrix_element,
            amplitude=amplitude,
            tag=tag,
        )

    def emit_elseif(
        self,
        predicate: Any,
        update: Any = None,
        *,
        matrix_element: Any = 1.0,
        amplitude: Any | None = None,
        tag: Any = None,
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Appends an ``elseif`` branch to the current conditional emission chain.

        This method must directly follow ``emit_if(...)`` or another
        ``emit_elseif(...)`` on the same term.
        """
        return apply_emission_clause(
            self,
            "emit_elseif",
            predicate,
            update,
            matrix_element=matrix_element,
            amplitude=amplitude,
            tag=tag,
        )

    def emit_else(
        self,
        update: Any = None,
        *,
        matrix_element: Any = 1.0,
        amplitude: Any | None = None,
        tag: Any = None,
    ) -> "SymbolicDiscreteJaxOperator":
        """
        Appends the ``else`` branch to the current conditional emission chain.

        This branch emits when all prior ``if`` / ``elseif`` predicates are
        false.
        """
        return apply_emission_clause(
            self,
            "emit_else",
            update,
            matrix_element=matrix_element,
            amplitude=amplitude,
            tag=tag,
        )

    #
    #
    #   Finalisation

    def build(self) -> "SymbolicOperator":
        """
        Seals all open terms and returns a
        :class:`~nkdsl.core.operator.SymbolicOperator`.

        Returns:
            :class:`~nkdsl.core.operator.SymbolicOperator` ready for
            compilation.

        Raises:
            ValueError: If no terms have been defined, or the current open
                term has no emissions.
        """
        self._seal_current()
        if not self._completed_terms:
            raise ValueError(
                "Cannot build an operator with zero terms. "
                "Add at least one iterator + emit() block."
            )

        debug_event(
            "building symbolic operator",
            scope="dsl",
            tag="DSL",
            operator_name=self._name,
            term_count=len(self._completed_terms),
        )

        # shift_mod metadata is only inferred when some term needs it.
        metadata: dict[str, Any] = {}
        if _terms_use_shift_mod(tuple(self._completed_terms)):
            metadata.update(_infer_shift_mod_spec_from_hilbert(self._hilbert))
            debug_event(
                "inferred shift_mod metadata",
                scope="dsl",
                tag="DSL",
                metadata_keys=tuple(sorted(metadata)),
            )

        resolved_dtype = _resolve_matrix_element_dtype(
            self._dtype,
            tuple(self._completed_terms),
        )
        if resolved_dtype != self._dtype:
            debug_event(
                "promoted operator dtype from matrix-element constants",
                scope="dsl",
                tag="DSL",
                operator_name=self._name,
                requested_dtype=self._dtype,
                resolved_dtype=resolved_dtype,
            )

        # Imported at call time rather than module import time.
        from nkdsl.core.operator import (
            SymbolicOperator,
        )

        op = SymbolicOperator(
            self._hilbert,
            self._name,
            tuple(self._completed_terms),
            dtype_str=resolved_dtype,
            is_hermitian=self._hermitian,
            metadata=metadata or None,
        )
        debug_event(
            "built symbolic operator",
            scope="dsl",
            tag="DSL",
            operator_name=op.name,
            term_count=op.term_count,
        )
        return op

    def compile(
        self,
        *,
        backend: str = "jax",
        operator_lowering: str = "netket_discrete_jax",
        deduplicate_connected_components: bool = True,
        cache: bool = True,
    ) -> Any:
        """
        Convenience shortcut: ``.build().compile(...)``.

        All keyword arguments are forwarded unchanged to the built
        operator's ``compile`` method.

        Returns:
            Executable compiled operator instance.
        """
        debug_event(
            "compiling directly from dsl builder",
            scope="dsl",
            tag="DSL",
            operator_name=self._name,
            backend=backend,
            operator_lowering=operator_lowering,
            deduplicate_connected_components=deduplicate_connected_components,
            cache=cache,
        )
        return self.build().compile(
            backend=backend,
            operator_lowering=operator_lowering,
            deduplicate_connected_components=deduplicate_connected_components,
            cache=cache,
        )

    def __getattr__(self, name: str) -> Any:
        """
        Resolves dynamically-registered iterator/predicate/emission clause methods.

        This enables fluent user extensions such as ``builder.my_iterator(...)``
        without adding concrete methods to this class.

        Raises:
            AttributeError: If *name* is not a registered clause of any kind.
        """
        # Registries are consulted lazily, in order: iterator, predicate,
        # emission.  The first one that knows *name* wins.
        lookups = (
            (resolve_iterator_clause, apply_iterator_clause),
            (resolve_predicate_clause, apply_predicate_clause),
            (resolve_emission_clause, apply_emission_clause),
        )
        for resolve_clause, apply_clause in lookups:
            clause = resolve_clause(name)
            if clause is not None:
                return self._bind_clause(name, clause, apply_clause)

        raise AttributeError(f"{type(self).__name__!s} object has no attribute {name!r}")

    def __dir__(self) -> list[str]:
        """
        Includes registered clause names in interactive completion output.
        """
        names = set(super().__dir__())
        names.update(available_iterator_clause_names())
        names.update(available_predicate_clause_names())
        names.update(available_emission_clause_names())
        return sorted(names)

    def __repr__(self) -> str:
        """Returns a summary with name, dtype, sealed-term count and open flag."""
        return (
            f"{type(self).__name__}("
            f"name={self._name!r}, "
            f"dtype={self._dtype!r}, "
            f"terms_sealed={len(self._completed_terms)}, "
            f"term_open={self._current is not None})"
        )
# Public API of this module: only the fluent operator builder is exported.
__all__ = [ "SymbolicDiscreteJaxOperator", ]