# Source code for nkdsl.dsl.rewrite

# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""
Fluent immutable update-program builder for the symbolic operator DSL.

The :class:`Update` class is designed for ergonomic construction of
site-update programs: every method is chainable and returns a *new*
immutable :class:`Update` instance.

Module-level factory functions (``shift``, ``shift_mod``, ``hop``, ``write``,
``swap``, ``permute``, ``affine``, ``scatter``, ``identity``) serve as
zero-boilerplate entry points; there is no need to first construct an empty
``Update()`` before chaining.

Examples::

    from nkdsl.dsl import shift, hop, write, swap, permute, scatter

    # Single operation, direct factory call
    update = shift("i", +1)

    # Compound update, chain freely
    update = shift("i", -1).shift("j", +1)  # hopping
    update = hop("i", "j")  # same hopping move
    update = swap("i", "j").write("k", 0)  # swap then zero
    update = permute("i", "j", "k")  # cyclic rotation
    update = affine("i", scale=2, bias=-1)  # x[i] = 2*x[i] - 1

    # Bulk scatter to static flat indices
    update = scatter([0, 5, 10], [val_0, val_5, val_10])

    # Identity (no-op), for diagonal operators
    update = identity()

    # Conditional update
    update = Update.cond(
        site("i") > 0,
        if_true=shift("i", -1),
        if_false=write("i", 0),
    )

Amplitude-side semantics
-------------------------
Amplitude expressions are evaluated after the emitted/connected state ``x'``
has been constructed for the current branch.

By default:
- ``site("i").value`` refers to the source configuration ``x[i]``.
- ``emitted("i").value`` refers to the connected/emitted configuration ``x'[i]``.

This allows matrix elements to depend on either the source state, the emitted
state, or both. In addition, ``wrap_mod(expr)`` applies the same Hilbert-aware
modulo-wrap semantics used by ``shift_mod(...)``.

Connected-state dedup semantics
-------------------------------
The padded connected-states representation returned by ``get_conn_padded`` is
deduplicated by default. If two terms (or two emissions within one term)
produce the same ``x'``, they are merged and their matrix elements are summed.
Zero-amplitude components are dropped before padding.

Set ``deduplicate_connected_components=False`` during ``.compile(...)`` to keep
raw per-branch output when needed.
"""

from __future__ import annotations

from typing import Any

from nkdsl.dsl.selectors import SiteSelector
from nkdsl.ir.expressions import (
    AmplitudeExpr,
    coerce_amplitude_expr,
)
from nkdsl.ir.predicates import (
    coerce_predicate_expr,
)
from nkdsl.ir.update import UpdateOp, UpdateProgram

#
#
#   Internal helpers


def _site_ref_to_expr(ref: Any) -> AmplitudeExpr:
    """
    Converts a site reference to an AmplitudeExpr for the site *index*.

    Accepts:
        - ``str`` label  -> ``symbol("site:<label>:index")``
        - :class:`SiteSelector` -> ``selector.as_site_ref()``
        - :class:`AmplitudeExpr` -> used as-is
        - integer         -> ``constant(float(idx))``; any
          :class:`numbers.Integral` is accepted (builtin ``int`` as well as
          integer scalars registered with that ABC, e.g. NumPy integers)

    Args:
        ref: The site reference to convert.

    Returns:
        The amplitude expression standing for the site index.

    Raises:
        TypeError: If *ref* is none of the accepted kinds.
    """
    # Function-scope import keeps this helper self-contained.
    import numbers

    if isinstance(ref, str):
        return AmplitudeExpr.symbol(f"site:{ref}:index")
    if isinstance(ref, SiteSelector):
        return ref.as_site_ref()
    if isinstance(ref, AmplitudeExpr):
        return ref
    # ``numbers.Integral`` instead of ``int`` so that integer scalars coming
    # out of array libraries are treated like plain Python integers, matching
    # how ``Update.scatter`` already normalises its indices with ``int(...)``.
    if isinstance(ref, numbers.Integral):
        return AmplitudeExpr.constant(float(ref))
    raise TypeError(
        f"site reference must be str, SiteSelector, int, or AmplitudeExpr; "
        f"got {type(ref).__name__!r}."
    )


#
#
#   Update class


class Update:
    """
    Immutable, chainable site-update program builder.

    Every instance method appends one operation and returns a **new**
    ``Update`` object; the original is never mutated.

    The canonical entry points are the module-level free functions
    (:func:`shift`, :func:`shift_mod`, :func:`hop`, :func:`write`,
    :func:`swap`, :func:`permute`, :func:`affine`, :func:`scatter`,
    :func:`identity`) which avoid the ``Update()`` boilerplate.

    Args:
        _program: Internal update program (do not pass manually).
    """

    __slots__ = ("_program",)

    def __init__(self, _program: UpdateProgram | None = None) -> None:
        # No program supplied means "start from an empty program".
        self._program: UpdateProgram = UpdateProgram() if _program is None else _program

    def _append(self, op: UpdateOp) -> "Update":
        """Returns a new ``Update`` whose program is this one plus *op*."""
        return Update(self._program.append(op))

    def to_program(self) -> UpdateProgram:
        """Returns the underlying immutable :class:`~nkdsl.ir.update.UpdateProgram`."""
        return self._program

    #
    #
    #   Primitive site mutations

    def shift(
        self,
        site_ref: str | SiteSelector | int | AmplitudeExpr,
        delta: Any,
    ) -> "Update":
        """
        Appends ``x'[i] = x[i] + delta``.

        Args:
            site_ref: Target site (label string, selector, or flat index).
            delta: Shift amount, numeric or amplitude expression.

        Returns:
            New ``Update`` with this operation appended.
        """
        return self._append(
            UpdateOp.from_mapping(
                kind="shift_site",
                params={
                    "site": _site_ref_to_expr(site_ref),
                    "delta": coerce_amplitude_expr(delta),
                },
            )
        )

    def shift_mod(
        self,
        site_ref: str | SiteSelector | int | AmplitudeExpr,
        delta: Any,
    ) -> "Update":
        """
        Appends a Hilbert-aware wrapped shift.

        Semantics are resolved from the enclosing operator's Hilbert space at
        build/compile time. For now this requires contiguous unit-spaced
        integer local_states such as [-m_max, ..., m_max].

        Resulting runtime semantics::

            x'[i] = ((x[i] + delta - state_min) % mod_span) + state_min

        Args:
            site_ref: Target site (label string, selector, or flat index).
            delta: Shift amount, numeric or amplitude expression.

        Returns:
            New ``Update`` with this operation appended.
        """
        return self._append(
            UpdateOp.from_mapping(
                kind="shift_mod_site",
                params={
                    "site": _site_ref_to_expr(site_ref),
                    "delta": coerce_amplitude_expr(delta),
                },
            )
        )

    def hop(
        self,
        src: str | SiteSelector | int | AmplitudeExpr,
        dst: str | SiteSelector | int | AmplitudeExpr,
        *,
        amount: Any = 1,
    ) -> "Update":
        """
        Appends an occupation transfer from *src* to *dst*.

        ``hop(src, dst, amount=a)`` is equivalent to
        ``shift(src, -a).shift(dst, +a)``. It is a readability helper for
        off-diagonal hopping terms in occupation-number models.

        Args:
            src: Source site to lower.
            dst: Destination site to raise.
            amount: Transferred amount, numeric or amplitude expression.

        Returns:
            New ``Update`` with the two shift operations appended.
        """
        # Coerce once so both shifts share the same expression object.
        amount_expr = coerce_amplitude_expr(amount)
        return self.shift(src, -amount_expr).shift(dst, amount_expr)

    def write(
        self,
        site_ref: str | SiteSelector | int | AmplitudeExpr,
        value: Any,
    ) -> "Update":
        """
        Appends ``x'[i] = value``.

        Args:
            site_ref: Target site.
            value: New quantum number, numeric or amplitude expression.

        Returns:
            New ``Update`` with this operation appended.
        """
        return self._append(
            UpdateOp.from_mapping(
                kind="write_site",
                params={
                    "site": _site_ref_to_expr(site_ref),
                    "value": coerce_amplitude_expr(value),
                },
            )
        )

    def swap(
        self,
        site_a: str | SiteSelector | int | AmplitudeExpr,
        site_b: str | SiteSelector | int | AmplitudeExpr,
    ) -> "Update":
        """
        Appends ``x'[a], x'[b] = x[b], x[a]``.

        Args:
            site_a: First site.
            site_b: Second site.

        Returns:
            New ``Update`` with this operation appended.
        """
        return self._append(
            UpdateOp.from_mapping(
                kind="swap_sites",
                params={
                    "site_a": _site_ref_to_expr(site_a),
                    "site_b": _site_ref_to_expr(site_b),
                },
            )
        )

    def permute(
        self,
        *site_refs: str | SiteSelector | int | AmplitudeExpr,
    ) -> "Update":
        """
        Appends a **cyclic rotation** over K sites.

        After the operation::

            x'[s0] ← x[s1],  x'[s1] ← x[s2],  ...,  x'[sK-1] ← x[s0]

        All K source values are captured from the current ``x'`` state
        *before* any writes are applied, so the rotation is atomic.

        Args:
            *site_refs: Two or more site references in rotation order.

        Returns:
            New ``Update`` with this operation appended.

        Raises:
            ValueError: If fewer than 2 site references are provided.
        """
        if len(site_refs) < 2:
            raise ValueError("permute requires at least 2 site references.")
        exprs = tuple(_site_ref_to_expr(s) for s in site_refs)
        return self._append(
            UpdateOp.from_mapping(
                kind="permute_sites",
                params={"sites": exprs},
            )
        )

    def affine(
        self,
        site_ref: str | SiteSelector | int | AmplitudeExpr,
        *,
        scale: Any,
        bias: Any = 0,
    ) -> "Update":
        """
        Appends ``x'[i] = scale * x[i] + bias``.

        Args:
            site_ref: Target site.
            scale: Multiplicative scale, numeric or amplitude expression.
            bias: Additive bias, numeric or amplitude expression (default 0).

        Returns:
            New ``Update`` with this operation appended.
        """
        return self._append(
            UpdateOp.from_mapping(
                kind="affine_site",
                params={
                    "site": _site_ref_to_expr(site_ref),
                    "scale": coerce_amplitude_expr(scale),
                    "bias": coerce_amplitude_expr(bias),
                },
            )
        )

    def scatter(
        self,
        flat_indices: list[int] | tuple[int, ...],
        values: list[Any] | tuple[Any, ...],
    ) -> "Update":
        """
        Appends bulk writes to static flat site indices.

        For each ``(flat_index, value)`` pair::

            x'[flat_index] = value

        Indices must be compile-time-constant integers (baked into the IR).
        Values may be arbitrary amplitude expressions.

        Args:
            flat_indices: Sequence of static integer site indices.
            values: Sequence of amplitude expressions (or coercible values).

        Returns:
            New ``Update`` with this operation appended.

        Raises:
            ValueError: If *flat_indices* and *values* have different lengths.
        """
        # Normalise both inputs to tuples first so the length check also
        # works for arbitrary iterables and the IR receives immutable data.
        flat_indices = tuple(int(i) for i in flat_indices)
        values = tuple(coerce_amplitude_expr(v) for v in values)
        if len(flat_indices) != len(values):
            raise ValueError(
                f"scatter: flat_indices and values must have the same length; "
                f"got {len(flat_indices)} indices and {len(values)} values."
            )
        return self._append(
            UpdateOp.from_mapping(
                kind="scatter",
                params={"flat_indices": flat_indices, "values": values},
            )
        )

    def invalidate(self, *, reason: str | None = None) -> "Update":
        """
        Marks this branch as **invalid** (zero matrix element).

        Useful for boundary conditions: emit a branch unconditionally and
        let the update program itself decide validity.

        Args:
            reason: Optional readable explanation.

        Returns:
            New ``Update`` with this operation appended.
        """
        # Without a reason the op is built with ``params=None``.
        params = {"reason": str(reason)} if reason is not None else None
        return self._append(UpdateOp.from_mapping(kind="invalidate_branch", params=params))

    #
    #
    #   Conditional update

    @staticmethod
    def _branch_ops(name: str, branch: Any) -> Any:
        """
        Returns the op sequence of one ``cond`` branch.

        Args:
            name: Keyword name of the branch, used in the error message.
            branch: The object passed for that branch.

        Raises:
            TypeError: If *branch* does not expose a callable ``to_program``.
        """
        to_program = getattr(branch, "to_program", None)
        if not callable(to_program):
            raise TypeError(
                f"Update.cond: {name} must be an Update; " f"got {type(branch).__name__!r}."
            )
        return to_program().ops

    @classmethod
    def cond(
        cls,
        predicate: Any,
        *,
        if_true: "Update",
        if_false: "Update | None" = None,
    ) -> "Update":
        """
        Returns a new ``Update`` wrapping a JAX-compatible conditional.

        At lowering time this becomes ``jax.lax.cond(predicate, ...)`` so both
        branches must produce the same output shape. The ``if_false`` branch
        defaults to the identity (no site changes) when not provided.

        Args:
            predicate: Branch predicate, :class:`~nkdsl.ir.predicates.PredicateExpr`
                or any coercible value (e.g. ``site("i").value > 0``).
            if_true: Update program to apply when *predicate* is true.
            if_false: Update program to apply when *predicate* is false.
                Defaults to identity (no writes).

        Returns:
            New ``Update`` wrapping the conditional.

        Raises:
            TypeError: If *if_true* (or a non-``None`` *if_false*) is not an
                ``Update``-like object exposing ``to_program()``.
        """
        # Check the branches up front: a mistake such as ``if_true=None``
        # should surface as a clear TypeError rather than an AttributeError.
        then_ops = cls._branch_ops("if_true", if_true)
        else_ops = () if if_false is None else cls._branch_ops("if_false", if_false)
        pred_expr = coerce_predicate_expr(predicate)
        op = UpdateOp.from_mapping(
            kind="cond_branch",
            params={
                "predicate": pred_expr,
                "then_ops": then_ops,
                "else_ops": else_ops,
            },
        )
        return cls(UpdateProgram(ops=(op,)))

    #
    #
    #   Dunder

    def __repr__(self) -> str:
        # Only the op kinds are shown, not their parameters.
        kinds = [op.kind for op in self._program.ops]
        return f"Update({kinds!r})"

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Update):
            return NotImplemented
        return self._program == other.to_program()

    def __hash__(self) -> int:
        # Hash follows the program so equal updates hash equally.
        return hash(self._program)
# Identity sentinel _IDENTITY = Update() # # # Module-level factory functions
def shift(
    site_ref: str | SiteSelector | int | AmplitudeExpr,
    delta: Any,
) -> Update:
    """
    Builds a one-operation ``Update`` adding *delta* to site *site_ref*.

    Delegates to ``Update.shift`` on the shared empty update.

    Example::

        shift("i", +1)  # raise site i by 1
        shift(0, -1)  # lower flat site 0 by 1
        shift("j", site("i").value)  # shift j by x[i]
    """
    shifted = _IDENTITY.shift(site_ref=site_ref, delta=delta)
    return shifted
def shift_mod(
    site_ref: str | SiteSelector | int | AmplitudeExpr,
    delta: Any,
) -> Update:
    """
    Builds a one-operation ``Update`` doing a Hilbert-aware wrapped modular shift.

    Delegates to ``Update.shift_mod`` on the shared empty update.

    Example::

        shift_mod("i", +1)
        shift_mod(0, -2)
    """
    wrapped = _IDENTITY.shift_mod(site_ref=site_ref, delta=delta)
    return wrapped
def hop(
    src: str | SiteSelector | int | AmplitudeExpr,
    dst: str | SiteSelector | int | AmplitudeExpr,
    *,
    amount: Any = 1,
) -> Update:
    """
    Builds an ``Update`` that moves occupation from *src* to *dst*.

    ``hop(src, dst, amount=a)`` is shorthand for
    ``shift(src, -a).shift(dst, +a)``; it delegates to ``Update.hop`` on the
    shared empty update.

    Example::

        hop("i", "j")  # x'[i] = x[i] - 1, x'[j] = x[j] + 1
        hop("i", "j", amount=2)  # transfer two quanta
    """
    transfer = _IDENTITY.hop(src=src, dst=dst, amount=amount)
    return transfer
def write(
    site_ref: str | SiteSelector | int | AmplitudeExpr,
    value: Any,
) -> Update:
    """
    Builds a one-operation ``Update`` storing *value* at site *site_ref*.

    Delegates to ``Update.write`` on the shared empty update.

    Example::

        write("i", 0)  # zero site i
        write(5, site("j").value)  # copy x[j] into flat site 5
    """
    written = _IDENTITY.write(site_ref=site_ref, value=value)
    return written
def swap(
    site_a: str | SiteSelector | int | AmplitudeExpr,
    site_b: str | SiteSelector | int | AmplitudeExpr,
) -> Update:
    """
    Builds a one-operation ``Update`` exchanging sites *site_a* and *site_b*.

    Delegates to ``Update.swap`` on the shared empty update.

    Example::

        swap("i", "j")  # exchange x[i] and x[j]
        swap(0, 10)  # exchange flat sites 0 and 10
    """
    exchanged = _IDENTITY.swap(site_a=site_a, site_b=site_b)
    return exchanged
def permute(
    *site_refs: str | SiteSelector | int | AmplitudeExpr,
) -> Update:
    """
    Builds a one-operation ``Update`` cyclically rotating K sites.

    Delegates to ``Update.permute`` on the shared empty update, forwarding
    the site references in the order given.

    Example::

        permute("i", "j", "k")  # x'[i]←x[j], x'[j]←x[k], x'[k]←x[i]
        permute(0, 5, 10)  # same with flat indices
    """
    rotation = _IDENTITY.permute(*site_refs)
    return rotation
def affine(
    site_ref: str | SiteSelector | int | AmplitudeExpr,
    *,
    scale: Any,
    bias: Any = 0,
) -> Update:
    """
    Builds a one-operation ``Update`` for ``x'[i] = scale * x[i] + bias``.

    Delegates to ``Update.affine`` on the shared empty update.

    Example::

        affine("i", scale=2, bias=-1)  # x'[i] = 2*x[i] - 1
        affine(0, scale=-1, bias=0)  # negate flat site 0
    """
    transformed = _IDENTITY.affine(site_ref=site_ref, scale=scale, bias=bias)
    return transformed
def scatter(
    flat_indices: list[int] | tuple[int, ...],
    values: list[Any] | tuple[Any, ...],
) -> Update:
    """
    Builds a one-operation ``Update`` bulk-writing to static flat indices.

    Delegates to ``Update.scatter`` on the shared empty update.

    Example::

        scatter([0, 10, 20], [1, -1, 0])  # write constant values
        scatter([0, 10], [site("i").value, 0])  # mixed expr / constant
    """
    bulk_write = _IDENTITY.scatter(flat_indices=flat_indices, values=values)
    return bulk_write
def identity() -> Update:
    """
    Gives back the identity (no-op) ``Update``.

    Every call returns the same shared module-level instance. Use it for
    diagonal operators where ``x' = x``::

        SymbolicDiscreteJaxOperator(hi, "diagonal").globally().emit(identity(), matrix_element=my_expr)
    """
    no_op = _IDENTITY
    return no_op
# Public API of this module: the builder class plus every module-level
# factory function. The private helper ``_site_ref_to_expr`` and the
# ``_IDENTITY`` sentinel are intentionally not exported.
__all__ = [
    "Update",
    "affine",
    "hop",
    "identity",
    "permute",
    "scatter",
    "shift",
    "shift_mod",
    "swap",
    "write",
]