# Source code for nkdsl.ir.predicates

# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Typed predicate-expression IR nodes for symbolic operators."""

from __future__ import annotations

import dataclasses
from typing import Any

from .expressions import AmplitudeExpr
from .expressions import coerce_amplitude_expr

# Operation tags accepted by ``PredicateExpr``; any other tag is rejected
# when a node is constructed.
_PREDICATE_OPS: frozenset[str] = frozenset(
    (
        "const",
        "not",
        "and",
        "or",
        "eq",
        "ne",
        "lt",
        "le",
        "gt",
        "ge",
    )
)


@dataclasses.dataclass(frozen=True, repr=False)
class PredicateExpr:
    """
    Immutable node of a boolean expression tree used for operator branch filtering.

    Nodes are normally created through the classmethod constructors below and
    can be combined with the ``&``, ``|`` and ``~`` operators.

    Attributes:
        op: Operation tag; must be a member of ``_PREDICATE_OPS``.
        args: Operands of the operation, in order.
    """

    op: str
    args: tuple = dataclasses.field(default_factory=tuple)

    def __post_init__(self) -> None:
        """Rejects operation tags that are not listed in ``_PREDICATE_OPS``."""
        if self.op in _PREDICATE_OPS:
            return
        allowed = sorted(_PREDICATE_OPS)
        raise ValueError(
            f"Unsupported predicate-expression op: {self.op!r}. " f"Allowed: {allowed}."
        )

    # ------------------------------------------------------------------
    # Constructors
    # ------------------------------------------------------------------

    @classmethod
    def constant(cls, value: bool) -> "PredicateExpr":
        """Creates a ``const`` node holding ``bool(value)``."""
        truth = bool(value)
        return cls(op="const", args=(truth,))

    @classmethod
    def not_(cls, operand: Any) -> "PredicateExpr":
        """Creates a ``not`` node around the coerced ``operand``."""
        inner = coerce_predicate_expr(operand)
        return cls(op="not", args=(inner,))

    @classmethod
    def and_(cls, *operands: Any) -> "PredicateExpr":
        """Creates an ``and`` node; with no operands the result is constant true."""
        if len(operands) == 0:
            return cls.constant(True)
        return cls(op="and", args=tuple(map(coerce_predicate_expr, operands)))

    @classmethod
    def or_(cls, *operands: Any) -> "PredicateExpr":
        """Creates an ``or`` node; with no operands the result is constant false."""
        if len(operands) == 0:
            return cls.constant(False)
        return cls(op="or", args=tuple(map(coerce_predicate_expr, operands)))

    @classmethod
    def _comparison(cls, op: str, left: Any, right: Any) -> "PredicateExpr":
        """Creates a binary comparison node after coercing both sides to amplitude expressions."""
        # The left operand is coerced before the right one.
        lhs = coerce_amplitude_expr(left)
        rhs = coerce_amplitude_expr(right)
        return cls(op=op, args=(lhs, rhs))

    @classmethod
    def eq(cls, left: Any, right: Any) -> "PredicateExpr":
        """Creates an ``eq`` (equality) comparison node."""
        return cls._comparison("eq", left, right)

    @classmethod
    def ne(cls, left: Any, right: Any) -> "PredicateExpr":
        """Creates an ``ne`` (inequality) comparison node."""
        return cls._comparison("ne", left, right)

    @classmethod
    def lt(cls, left: Any, right: Any) -> "PredicateExpr":
        """Creates an ``lt`` (strictly less than) comparison node."""
        return cls._comparison("lt", left, right)

    @classmethod
    def le(cls, left: Any, right: Any) -> "PredicateExpr":
        """Creates an ``le`` (less than or equal) comparison node."""
        return cls._comparison("le", left, right)

    @classmethod
    def gt(cls, left: Any, right: Any) -> "PredicateExpr":
        """Creates a ``gt`` (strictly greater than) comparison node."""
        return cls._comparison("gt", left, right)

    @classmethod
    def ge(cls, left: Any, right: Any) -> "PredicateExpr":
        """Creates a ``ge`` (greater than or equal) comparison node."""
        return cls._comparison("ge", left, right)

    # ------------------------------------------------------------------
    # Operator sugar: ``a & b``, ``a | b``, ``~a``
    # ------------------------------------------------------------------

    def __and__(self, other: Any) -> "PredicateExpr":
        return type(self).and_(self, other)

    def __rand__(self, other: Any) -> "PredicateExpr":
        return type(self).and_(other, self)

    def __or__(self, other: Any) -> "PredicateExpr":
        return type(self).or_(self, other)

    def __ror__(self, other: Any) -> "PredicateExpr":
        return type(self).or_(other, self)

    def __invert__(self) -> "PredicateExpr":
        return type(self).not_(self)

    # ------------------------------------------------------------------
    # Text forms
    # ------------------------------------------------------------------

    def __str__(self) -> str:
        # Infix form, produced by the module-level renderer.
        return _render_predicate(self)

    def __repr__(self) -> str:
        # Structural form showing the raw tag and operands.
        return f"PredicateExpr(op={self.op!r}, args={self.args!r})"


def _render_predicate(expr: "PredicateExpr") -> str:
    """
    Formats a predicate tree as an infix boolean string.

    Constants become ``true``/``false``, negation is prefixed with ``!``,
    conjunctions and disjunctions are joined with ``&&`` / ``||`` inside
    parentheses, and comparisons are written as ``(lhs <symbol> rhs)``.

    Args:
        expr: Predicate node to render.

    Returns:
        The rendered string.
    """
    from .expressions import _render_amplitude, AmplitudeExpr

    def _show_predicate_operand(node: Any) -> str:
        # Nested predicates recurse; anything else falls back to ``repr``.
        if isinstance(node, PredicateExpr):
            return _render_predicate(node)
        return repr(node)

    def _show_amplitude_operand(node: Any) -> str:
        # Amplitude expressions use their own renderer; anything else uses ``repr``.
        if isinstance(node, AmplitudeExpr):
            return _render_amplitude(node)
        return repr(node)

    kind, operands = expr.op, expr.args

    if kind == "const":
        return "true" if operands[0] else "false"

    if kind == "not":
        return "!" + _render_predicate(operands[0])

    if kind == "and" or kind == "or":
        joiner = " && " if kind == "and" else " || "
        pieces = [_show_predicate_operand(item) for item in operands]
        return "(" + joiner.join(pieces) + ")"

    # Binary comparisons: both operands are expected to be amplitude expressions.
    symbols = {"eq": "==", "ne": "!=", "lt": "<", "le": "<=", "gt": ">", "ge": ">="}
    symbol = symbols.get(kind)
    if symbol is not None:
        lhs = _show_amplitude_operand(operands[0])
        rhs = _show_amplitude_operand(operands[1])
        return f"({lhs} {symbol} {rhs})"

    # Generic call-style form for any tag not handled above.
    pieces = [_show_predicate_operand(item) for item in operands]
    return f"{kind}({', '.join(pieces)})"


def _collect_free_symbols_pred(expr: "PredicateExpr", result: "set[str]") -> None:
    """
    Walks a predicate tree and gathers free symbol names into ``result``.

    Nested predicates are visited recursively, in argument order. Amplitude
    operands are handed to ``_collect_free_symbols`` together with the same
    ``result`` object. Operands of any other type are ignored.

    Args:
        expr: Predicate node to inspect.
        result: Accumulator passed through to every nested call.
    """
    from .expressions import AmplitudeExpr, _collect_free_symbols

    for child in expr.args:
        if isinstance(child, PredicateExpr):
            _collect_free_symbols_pred(child, result)
            continue
        if isinstance(child, AmplitudeExpr):
            _collect_free_symbols(child, result)


def coerce_predicate_expr(value: Any) -> PredicateExpr:
    """
    Converts a user-supplied value into a ``PredicateExpr``.

    Existing predicate nodes are returned unchanged and ``bool`` values are
    wrapped as constant nodes.

    Args:
        value: Value to convert.

    Returns:
        The corresponding predicate node.

    Raises:
        TypeError: If ``value`` is an ``AmplitudeExpr`` (which needs an explicit
            comparison) or any other unsupported type.
    """
    if isinstance(value, PredicateExpr):
        return value
    if isinstance(value, bool):
        return PredicateExpr.constant(value)

    # Everything else is rejected; amplitude expressions get a dedicated hint.
    if isinstance(value, AmplitudeExpr):
        message = (
            "Cannot use an AmplitudeExpr directly as a predicate. "
            "Use an explicit comparison, e.g. expr > 0."
        )
    else:
        message = (
            f"Cannot coerce {type(value)!r} into a PredicateExpr. "
            "Use bool values or PredicateExpr objects."
        )
    raise TypeError(message)


__all__ = [
    "PredicateExpr",
    "coerce_predicate_expr",
    "_collect_free_symbols_pred",
    "_render_predicate",
]