# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unified symbolic operator type produced by the Operator DSL builder."""
from __future__ import annotations
from typing import Any
import numpy as np
from netket.hilbert import DiscreteHilbert
from nkdsl.core.base import AbstractSymbolicOperator
from nkdsl.debug import event as debug_event
from nkdsl.ir.program import SymbolicOperatorIR
def _resolve_dtype(a: str, b: str) -> str:
"""Promotes two dtype strings to their common type."""
return np.result_type(np.dtype(a), np.dtype(b)).name
def _scalar_str(scalar: Any) -> str:
"""Compact string for a numeric scalar used in operator name generation."""
if isinstance(scalar, complex):
return repr(scalar)
if isinstance(scalar, float) and scalar == int(scalar) and abs(scalar) < 1e15:
return str(int(scalar))
return repr(scalar)
def _promote_dtype_for_scalar(dtype: str, scalar: Any) -> str:
"""Returns the dtype required to hold the product of dtype * scalar."""
base = np.dtype(dtype)
scalar_dtype = np.asarray(scalar).dtype
if np.issubdtype(scalar_dtype, np.complexfloating):
return np.result_type(base, scalar_dtype).name
return base.name
def _merge_metadata(a: dict[str, Any] | None, b: dict[str, Any] | None) -> dict[str, Any]:
out = dict(a or {})
for k, v in (b or {}).items():
if k in out and out[k] != v:
raise ValueError(
f"Cannot merge symbolic operator metadata key {k!r} with different values."
)
out[k] = v
return out
class SymbolicOperator(AbstractSymbolicOperator):
    """
    A symbolic operator built via the :class:`~nkdsl.dsl.op.SymbolicDiscreteJaxOperator` DSL.

    ``SymbolicOperator`` is the canonical result of ``SymbolicDiscreteJaxOperator(...).build()``.
    It holds an ordered list of typed IR terms and provides a ``.compile()``
    method to lower them to an executable
    :class:`~nkdsl.core.compiled.CompiledOperator`.

    Instances are **not** directly executable: calling ``get_conn_padded``
    before compilation raises
    :class:`~nkdsl.errors.SymbolicOperatorExecutionError`.

    Attributes:
        name: User-facing operator name.
        hilbert: The NetKet Hilbert space.
        dtype: Matrix-element dtype.
        is_hermitian: Whether this operator is declared Hermitian.

    Example::

        op = (
            SymbolicDiscreteJaxOperator(hi, "hopping")
            .for_each_pair("i", "j")
            .where(site("i") > 0)
            .emit(shift("i", -1).shift("j", +1), matrix_element=1.0)
            .build()
        )
        compiled = op.compile()
        xp, mels = compiled.get_conn_padded(x_batch)
    """

    __slots__ = ("_ir_terms",)

    def __init__(
        self,
        hilbert: DiscreteHilbert,
        name: str,
        ir_terms: tuple,  # tuple[SymbolicIRTerm, ...]
        *,
        dtype_str: str = "complex64",
        is_hermitian: bool = False,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Create a symbolic operator from already-built IR terms.

        Args:
            hilbert: The NetKet discrete Hilbert space.
            name: User-facing operator name.
            ir_terms: Iterable of IR terms; stored internally as a tuple.
            dtype_str: Matrix-element dtype as a string.
            is_hermitian: Whether the operator is declared Hermitian.
            metadata: Optional metadata mapping forwarded to the base class.
        """
        super().__init__(
            hilbert,
            name=name,
            dtype_str=dtype_str,
            is_hermitian=is_hermitian,
            metadata=metadata,
        )
        # Freeze the terms so the operator's term list cannot be mutated later.
        self._ir_terms: tuple = tuple(ir_terms)

    @property
    def name(self) -> str:
        """User-facing operator name."""
        return self._name_val

    @property
    def term_count(self) -> int:
        """Number of IR terms in this operator."""
        return len(self._ir_terms)

    @property
    def free_symbols(self) -> frozenset:
        """
        Returns the set of free (non-iterator-bound) symbol names across all terms.

        Free symbols are unresolved named parameters such as ``symbol("kappa")``.
        Symbols declared with ``default=...`` are considered resolved and are not
        included in this set.
        """
        # Union of every term's own free-symbol set; empty when there are no terms.
        return frozenset().union(*(term.free_symbols for term in self._ir_terms))

    def to_ir(self) -> SymbolicOperatorIR:
        """Builds the symbolic operator IR for compiler consumption.

        Emits a debug event before and after the IR is materialized.

        Returns:
            The :class:`SymbolicOperatorIR` built from this operator's terms.
        """
        debug_event(
            "materializing symbolic operator ir",
            scope="ir",
            tag="IR",
            operator_name=self._name_val,
            term_count=len(self._ir_terms),
            dtype=self._dtype_val,
        )
        ir = SymbolicOperatorIR.from_terms(
            operator_name=self._name_val,
            hilbert_size=int(self.hilbert.size),
            dtype_str=self._dtype_val,
            is_hermitian=self._is_hermitian_val,
            terms=self._ir_terms,
            # Empty metadata is passed as None rather than as an empty mapping.
            metadata=self._metadata_dict if self._metadata_dict else None,
        )
        debug_event(
            "materialized symbolic operator ir",
            scope="ir",
            tag="IR",
            operator_name=ir.operator_name,
            term_count=ir.term_count,
            free_symbol_count=len(ir.free_symbols),
        )
        return ir

    def estimate_max_conn_size(self) -> int:
        """Returns the aggregate static max-connection bound across all terms.

        A term with an explicit ``max_conn_size_hint`` contributes that hint.
        Otherwise the term contributes the number of iterator index sets times
        the number of effective emissions.

        Returns:
            The summed bound, never less than ``1``.

        Raises:
            TypeError: If a term without a hint uses an iterator that is not a
                ``KBodyIteratorSpec``.
        """
        # Imported locally (presumably to avoid an import cycle at module load).
        from nkdsl.ir.term import KBodyIteratorSpec

        total = 0
        for term in self._ir_terms:
            if term.max_conn_size_hint is not None:
                total += int(term.max_conn_size_hint)
                continue
            if not isinstance(term.iterator, KBodyIteratorSpec):
                raise TypeError(
                    f"Unsupported iterator type {type(term.iterator).__name__!r}; "
                    "expected KBodyIteratorSpec."
                )
            n_emissions = len(term.effective_emissions)
            n_index_sets = len(term.iterator.index_sets)
            total += n_index_sets * n_emissions
        return max(1, total)

    def compile(
        self,
        *,
        backend: str = "jax",
        operator_lowering: str = "netket_discrete_jax",
        deduplicate_connected_components: bool = True,
        cache: bool = True,
        compiler: Any = None,
    ) -> Any:  # -> CompiledOperator
        """Lowers this symbolic operator to an executable :class:`~nkdsl.core.compiled.CompiledOperator`.

        Args:
            backend: Backend target (currently only ``"jax"`` is supported).
            operator_lowering: Registered operator-lowering target name.
            deduplicate_connected_components: Whether to merge duplicate
                connected states and drop zero matrix elements at runtime.
            cache: Whether to cache the compiled artifact in the process-level store.
            compiler: Optional :class:`~nkdsl.compiler.SymbolicCompiler`
                instance. When ``None`` a new compiler is constructed from the
                options above; when given, it is used as-is and the option
                arguments are only reported in the debug event.

        Returns:
            Executable :class:`~nkdsl.core.compiled.CompiledOperator`.
        """
        from nkdsl.compiler.compiler import (
            SymbolicCompiler,
        )
        from nkdsl.compiler.core.options import (
            SymbolicCompilerOptions,
        )

        # Explicit None check: a supplied compiler must never be discarded just
        # because it happens to evaluate as falsy.
        if compiler is not None:
            active_compiler = compiler
        else:
            active_compiler = SymbolicCompiler(
                options=SymbolicCompilerOptions(
                    backend_preference=backend,
                    operator_lowering=operator_lowering,
                    deduplicate_connected_components=deduplicate_connected_components,
                    cache_enabled=cache,
                )
            )
        debug_event(
            "compiling symbolic operator",
            scope="compile",
            tag="COMPILER",
            operator_name=self._name_val,
            backend=backend,
            operator_lowering=operator_lowering,
            deduplicate_connected_components=deduplicate_connected_components,
            cache=cache,
        )
        return active_compiler.compile_operator(self)

    def _apply_scalar(self, scalar: "int | float | complex") -> "SymbolicOperator":
        """Return a new operator whose terms are all scaled by ``scalar``.

        The result is named ``"(<scalar> * <name>)"``, reuses this operator's
        Hilbert space and metadata, and has its dtype promoted when the scalar
        is complex. The Hermitian flag is kept only for non-complex scalars.

        Args:
            scalar: The constant factor to multiply every term with.

        Returns:
            The scaled :class:`SymbolicOperator`.
        """
        from nkdsl.ir.expressions import (
            AmplitudeExpr,
        )
        from nkdsl.ir.term import _scale_ir_term

        scale_expr = AmplitudeExpr.constant(scalar)
        new_terms = tuple(_scale_ir_term(t, scale_expr) for t in self._ir_terms)
        # Any complex-typed factor (Python or NumPy scalar, checked by type and
        # not by value) conservatively drops the Hermitian declaration; this
        # matches the complex detection used for the dtype promotion below.
        is_hermitian = self._is_hermitian_val and not np.iscomplexobj(scalar)
        new_dtype = _promote_dtype_for_scalar(self._dtype_val, scalar)
        scaled = SymbolicOperator(
            self.hilbert,
            f"({_scalar_str(scalar)} * {self.name})",
            new_terms,
            dtype_str=new_dtype,
            is_hermitian=is_hermitian,
            metadata=self._metadata_dict or None,
        )
        debug_event(
            "scaled symbolic operator",
            scope="dsl",
            tag="DSL",
            source_operator=self.name,
            scalar=scalar,
            target_operator=scaled.name,
        )
        return scaled

    def __add__(self, other: Any):
        """Compose with another operator using NetKet sum machinery."""
        return super().__add__(other)

    def __radd__(self, other: Any):
        """Reflected addition; delegates to the base-class implementation."""
        return super().__radd__(other)

    def __repr__(self) -> str:
        """Return a short description with name, term count, dtype and Hilbert space."""
        return (
            f"SymbolicOperator("
            f"name={self.name!r}, "
            f"terms={self.term_count}, "
            f"dtype={self._dtype_val!r}, "
            f"hilbert={self.hilbert})"
        )
# Public API of this module: only the SymbolicOperator class is exported;
# the underscore-prefixed helpers above are module-private.
__all__ = ["SymbolicOperator"]