Source code for nkdsl.ir.program

# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Top-level symbolic operator IR container."""

from __future__ import annotations

import dataclasses
import hashlib
import json
from typing import Any

from nkdsl.debug import event as debug_event

from .term import SymbolicIRTerm

_IR_MODES: frozenset[str] = frozenset({"symbolic", "jax_kernel"})


def _serialize_amplitude(expr: Any) -> Any:
    """Recursively serializes an AmplitudeExpr to a JSON-safe structure."""
    from .expressions import AmplitudeExpr

    if isinstance(expr, AmplitudeExpr):
        return {"op": expr.op, "args": [_serialize_amplitude(a) for a in expr.args]}
    if isinstance(expr, tuple):
        return [_serialize_amplitude(v) for v in expr]
    return expr


def _serialize_predicate(expr: Any) -> Any:
    """Recursively serializes a PredicateExpr to a JSON-safe structure."""
    from .predicates import PredicateExpr
    from .expressions import AmplitudeExpr

    if isinstance(expr, PredicateExpr):
        return {"op": expr.op, "args": [_serialize_predicate(a) for a in expr.args]}
    if isinstance(expr, AmplitudeExpr):
        return _serialize_amplitude(expr)
    if isinstance(expr, tuple):
        return [_serialize_predicate(v) for v in expr]
    return expr


def _serialize_update(program: Any) -> Any:
    """Serializes an UpdateProgram to a JSON-safe structure."""
    from .update import UpdateProgram, UpdateOp

    if not isinstance(program, UpdateProgram):
        return repr(program)
    ops = []
    for op in program.ops:
        params_dict = {}
        for k, v in op.params:
            params_dict[k] = _serialize_amplitude(v)
        ops.append({"kind": op.kind, "params": params_dict})
    return ops


[docs] @dataclasses.dataclass(frozen=True, repr=False) class SymbolicOperatorIR: """ Immutable symbolic operator IR container. Attributes: operator_name: Name of the operator this IR represents. mode: IR mode (``symbolic`` for DSL-built operators, ``jax_kernel`` for direct JAX-kernel operators). hilbert_size: Size of the Hilbert space (number of sites). dtype_str: String representation of the matrix-element dtype. is_hermitian: Whether the source operator is declared Hermitian. terms: Declarative term tuple for ``symbolic`` mode. metadata: Optional stable metadata tuple. """ operator_name: str mode: str hilbert_size: int dtype_str: str is_hermitian: bool terms: tuple = dataclasses.field(default_factory=tuple) metadata: tuple = dataclasses.field(default_factory=tuple) def __post_init__(self) -> None: if not self.operator_name.strip(): raise ValueError("operator_name must be a non-empty string.") if self.mode not in _IR_MODES: raise ValueError(f"Unsupported IR mode: {self.mode!r}. Allowed: {sorted(_IR_MODES)}.") if self.hilbert_size <= 0: raise ValueError(f"hilbert_size must be a positive integer; got {self.hilbert_size!r}.") @classmethod def from_terms( cls, *, operator_name: str, hilbert_size: int, dtype_str: str, is_hermitian: bool, terms: tuple[SymbolicIRTerm, ...], metadata: dict[str, Any] | None = None, ) -> "SymbolicOperatorIR": """Builds declarative symbolic-mode operator IR.""" if not terms: raise ValueError("Symbolic operator IR requires at least one term.") meta_tuple: tuple if metadata is None: meta_tuple = () else: meta_tuple = tuple(sorted(metadata.items())) ir = cls( operator_name=str(operator_name), mode="symbolic", hilbert_size=int(hilbert_size), dtype_str=str(dtype_str), is_hermitian=bool(is_hermitian), terms=terms, metadata=meta_tuple, ) debug_event( "constructed symbolic operator ir", scope="ir", tag="IR", operator_name=ir.operator_name, term_count=ir.term_count, metadata_keys=tuple(k for k, _ in ir.metadata), ) return ir @property def term_count(self) -> int: """Returns number of declarative terms.""" return len(self.terms) def metadata_dict(self) -> dict[str, Any]: """Returns metadata in dictionary form.""" return dict(self.metadata) def as_dict(self) -> dict[str, Any]: """Returns a JSON-serializable dictionary representation.""" return { "operator_name": self.operator_name, "mode": self.mode, "hilbert_size": self.hilbert_size, "dtype_str": self.dtype_str, "is_hermitian": self.is_hermitian, "terms": [ { "name": t.name, "iterator": { "kind": t.iterator.kind, "label_a": t.iterator.label_a, "label_b": t.iterator.label_b, }, "predicate": _serialize_predicate(t.predicate), "update_program": _serialize_update(t.update_program), "amplitude": _serialize_amplitude(t.amplitude), "branch_tag": t.branch_tag, "emissions": [ { "predicate": _serialize_predicate(em.predicate), "update_program": _serialize_update(em.update_program), "amplitude": _serialize_amplitude(em.amplitude), "branch_tag": em.branch_tag, } for em in t.effective_emissions ], "metadata": list(t.metadata), "max_conn_size_hint": t.max_conn_size_hint, } for t in self.terms ], "metadata": list(self.metadata), } def static_fingerprint(self) -> str: """Returns a deterministic SHA-256 digest over the static IR payload.""" raw = json.dumps(self.as_dict(), sort_keys=True, default=str) digest = hashlib.sha256(raw.encode("utf-8")).hexdigest() debug_event( "computed symbolic ir fingerprint", scope="ir", tag="IR", operator_name=self.operator_name, fingerprint_prefix=digest[:16], ) return digest @property def free_symbols(self) -> frozenset: """ Returns the union of free symbol names across all terms. Free symbols are named parameters (e.g. ``symbol("kappa")``) that are not bound by any iterator label and still unresolved. Symbols declared with ``default=...`` are treated as resolved and are not included. """ result: set = set() for term in self.terms: result |= term.free_symbols return frozenset(result) def __str__(self) -> str: """ Returns a structured IR dump in the symbolic IR format. The format is inspired by LLVM IR / MLIR. Named operator blocks with typed terms, readable iterator descriptions, infix amplitude expressions, and pseudocode update programs. Example output:: symbolic.operator @"hopping" [dtype=complex64, hermitian=false, hilbert_size=16] { ; 1 term(s) term #0 "0" [kbody, n_iter=256, max_conn_size=256] { iterate: for (i, j) in [(0, 0), (0, 1), (0, 2), ... +253 more] where: (x[i] > 0) emit #0: update: x'[i] = (x[i] + -1); x'[j] = (x[j] + 1) amplitude: 1 } } """ hermitian_str = "true" if self.is_hermitian else "false" lines = [ f'symbolic.operator @"{self.operator_name}" ' f"[dtype={self.dtype_str}, hermitian={hermitian_str}, " f"hilbert_size={self.hilbert_size}] {{" ] fs = self.free_symbols if fs: fs_str = ", ".join(f"%{s}" for s in sorted(fs)) lines.append(f" ; {len(self.terms)} term(s), free symbols: [{fs_str}]") else: lines.append(f" ; {len(self.terms)} term(s)") for idx, term in enumerate(self.terms): lines.append("") lines.extend(term.to_ir_lines(idx=idx, indent=" ")) lines.append("") lines.append("}") return "\n".join(lines) def __repr__(self) -> str: return ( f"SymbolicOperatorIR(" f"operator_name={self.operator_name!r}, " f"mode={self.mode!r}, " f"hilbert_size={self.hilbert_size}, " f"term_count={self.term_count}, " f"is_hermitian={self.is_hermitian})" )
__all__ = ["SymbolicOperatorIR"]