# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Build-time expression context for symbolic operator DSL callbacks."""
from __future__ import annotations
from typing import Any
from nkdsl.dsl.selectors import SiteSelector
from nkdsl.dsl.selectors import emitted
from nkdsl.dsl.selectors import site
from nkdsl.dsl.selectors import source_index
from nkdsl.dsl.selectors import symbol
from nkdsl.dsl.selectors import target_index
from nkdsl.ir.expressions import AmplitudeExpr
from nkdsl.ir.expressions import (
coerce_amplitude_expr,
)
from nkdsl.ir.predicates import PredicateExpr
from nkdsl.ir.predicates import coerce_predicate_expr
_UNSET_SYMBOL_DEFAULT = object()
[docs]
class ExpressionContext:
"""
Utility context passed to DSL callables at build time.
This context only builds IR expression nodes and never captures Python-runtime callbacks.
Example:
>>> def my_amplitude(ctx):
... i = ctx.site("i")
... return ctx.sqrt(i.value + 1)
"""
__slots__ = ()
def const(self, value: complex | float | int) -> AmplitudeExpr:
"""Returns a constant amplitude expression."""
return AmplitudeExpr.constant(value)
def symbol(
self,
name: str,
*,
default: Any = _UNSET_SYMBOL_DEFAULT,
doc: str = "",
dtype: Any | None = None,
) -> AmplitudeExpr:
"""Returns a free symbolic amplitude expression.
Args:
name: Symbol name.
default: Optional default value.
doc: Optional symbol documentation string.
dtype: Optional NumPy-compatible symbol dtype declaration.
"""
if default is not _UNSET_SYMBOL_DEFAULT:
return symbol(name, default=default, doc=doc, dtype=dtype)
return symbol(name, doc=doc, dtype=dtype)
def sqrt(self, operand: Any) -> AmplitudeExpr:
"""Returns a square-root amplitude expression."""
return AmplitudeExpr.sqrt(operand)
def conj(self, operand: Any) -> AmplitudeExpr:
"""Returns a complex-conjugate amplitude expression."""
return AmplitudeExpr.conj(operand)
def neg(self, operand: Any) -> AmplitudeExpr:
"""Returns a negated amplitude expression."""
return AmplitudeExpr.neg(operand)
def site(self, label: str) -> SiteSelector:
"""Returns a site selector by label."""
return site(label)
def emitted(self, label: str) -> SiteSelector:
"""Returns an emitted-state selector by label."""
return emitted(label)
def source_index(self, flat_index: int) -> AmplitudeExpr:
"""
Returns a static source-configuration read ``x[flat_index]``.
Args:
flat_index: Non-negative flat index into the source configuration ``x``.
Returns:
Static-index amplitude expression.
"""
return source_index(flat_index)
def target_index(self, flat_index: int) -> AmplitudeExpr:
"""
Returns a static emitted/target-configuration read ``x'[flat_index]``.
Args:
flat_index: Non-negative flat index into the emitted configuration ``x'``.
Returns:
Static-emitted-index amplitude expression.
"""
return target_index(flat_index)
def all_of(self, *operands: Any) -> PredicateExpr:
"""Builds logical conjunction over provided predicate values."""
return PredicateExpr.and_(*operands)
def any_of(self, *operands: Any) -> PredicateExpr:
"""Builds logical disjunction over provided predicate values."""
return PredicateExpr.or_(*operands)
def not_(self, operand: Any) -> PredicateExpr:
"""Builds logical negation over one predicate value."""
return PredicateExpr.not_(operand)
def eq(self, left: Any, right: Any) -> PredicateExpr:
"""Builds equality predicate expression."""
return PredicateExpr.eq(left, right)
def ne(self, left: Any, right: Any) -> PredicateExpr:
"""Builds inequality predicate expression."""
return PredicateExpr.ne(left, right)
def lt(self, left: Any, right: Any) -> PredicateExpr:
"""Builds strict-less-than predicate expression."""
return PredicateExpr.lt(left, right)
def le(self, left: Any, right: Any) -> PredicateExpr:
"""Builds less-than-or-equal predicate expression."""
return PredicateExpr.le(left, right)
def gt(self, left: Any, right: Any) -> PredicateExpr:
"""Builds strict-greater-than predicate expression."""
return PredicateExpr.gt(left, right)
def ge(self, left: Any, right: Any) -> PredicateExpr:
"""Builds greater-than-or-equal predicate expression."""
return PredicateExpr.ge(left, right)
def coerce_amplitude(self, value: Any) -> AmplitudeExpr:
"""Coerces one value into an amplitude expression."""
return coerce_amplitude_expr(value)
def coerce_predicate(self, value: Any) -> PredicateExpr:
"""Coerces one value into a predicate expression."""
return coerce_predicate_expr(value)
def pow(self, base: Any, exponent: Any) -> AmplitudeExpr:
"""Returns a power expression ``base ** exponent``."""
return AmplitudeExpr.pow(base, exponent)
def abs_(self, operand: Any) -> AmplitudeExpr:
"""Returns an absolute-value expression ``|operand|``."""
return AmplitudeExpr.abs_(operand)
def wrap_mod(self, operand: Any) -> AmplitudeExpr:
"""Returns a Hilbert-aware modulo-wrapped amplitude expression."""
return AmplitudeExpr.wrap_mod(operand)
def sq_norm(self, *components: Any) -> AmplitudeExpr:
"""
Returns the squared L2 norm of *components*: ``c0² + c1² + ...``.
Each component is coerced to an :class:`AmplitudeExpr` before squaring.
Equivalent to ``sum(c * c for c in components)``.
Args:
*components: Two or more amplitude expressions (or coercible values).
Returns:
Amplitude expression for the squared norm.
"""
if not components:
raise ValueError("sq_norm requires at least one component.")
exprs = [coerce_amplitude_expr(c) for c in components]
result = AmplitudeExpr.mul(exprs[0], exprs[0])
for e in exprs[1:]:
result = AmplitudeExpr.add(result, AmplitudeExpr.mul(e, e))
return result
def norm2(self, *components: Any) -> AmplitudeExpr:
"""
Returns the L2 norm of *components*: ``sqrt(c0² + c1² + ...)``.
Args:
*components: Two or more amplitude expressions (or coercible values).
Returns:
Amplitude expression for the Euclidean norm.
"""
return AmplitudeExpr.sqrt(self.sq_norm(*components))
def edge_value(
self,
edge_idx: int,
gauge_copy: int,
n_edges_per_copy: int,
) -> AmplitudeExpr:
"""
Returns the charge at ``x[gauge_copy * n_edges_per_copy + edge_idx]``.
This is the primary access pattern for U(1)^G (or SU(N)^G) gauge
theories where the flat configuration array is laid out as G consecutive
blocks of ``n_edges_per_copy`` entries each.
Args:
edge_idx: Zero-based edge index within one gauge copy.
gauge_copy: Gauge copy index (0-based).
n_edges_per_copy: Number of edge sites per gauge copy.
Returns:
Static-index amplitude expression reading ``x[g * E + e]``.
"""
flat = int(gauge_copy) * int(n_edges_per_copy) + int(edge_idx)
return AmplitudeExpr.static_index(flat)
def emitted_edge_value(
self,
edge_idx: int,
gauge_copy: int,
n_edges_per_copy: int,
) -> AmplitudeExpr:
"""Returns the emitted/connected charge at x'[g*E + e]."""
flat = int(gauge_copy) * int(n_edges_per_copy) + int(edge_idx)
return AmplitudeExpr.static_emitted_index(flat)
def edge_components(
self,
edge_idx: int,
n_edges_per_copy: int,
gauge_dim: int = 3,
) -> list[AmplitudeExpr]:
"""
Returns all *gauge_dim* charge components for one edge.
For a U(1)^G Hilbert space with G gauge copies, returns the list
``[x[0*E + e], x[1*E + e], …, x[(G-1)*E + e]]``.
Args:
edge_idx: Zero-based edge index.
n_edges_per_copy: Number of edges per gauge copy (E).
gauge_dim: Number of gauge copies (G, default 3 for U(1)^3).
Returns:
List of G static-index amplitude expressions.
"""
return [self.edge_value(edge_idx, g, n_edges_per_copy) for g in range(gauge_dim)]
def edge_sq_norm(
self,
edge_idx: int,
n_edges_per_copy: int,
gauge_dim: int = 3,
) -> AmplitudeExpr:
"""
Returns the squared L2 norm of the charge vector at *edge_idx*.
Equivalent to ``sq_norm(*edge_components(edge_idx, ...))``.
"""
return self.sq_norm(*self.edge_components(edge_idx, n_edges_per_copy, gauge_dim))
def edge_norm(
self,
edge_idx: int,
n_edges_per_copy: int,
gauge_dim: int = 3,
) -> AmplitudeExpr:
"""
Returns the L2 norm of the charge vector at *edge_idx*.
Equivalent to ``norm2(*edge_components(edge_idx, ...))``.
"""
return self.norm2(*self.edge_components(edge_idx, n_edges_per_copy, gauge_dim))
# Public API of this module: only the context class is exported by
# ``from ... import *``; the sentinel above stays private.
__all__ = ["ExpressionContext"]