Custom Predicates

Predicates define when a term (or branch) is active over the iterator domain. This notebook focuses on writing reusable predicate clauses that feel native in the fluent API and produce predictable behavior in compiled operators.

We will implement two clauses, compose them with the built-in where(...), and inspect how predicate logic shows up in execution and in the IR.

Predicate semantics in one page

A predicate clause is just a way to build a boolean expression and append it to the current term. Multiple predicate clauses compose with logical AND.

That means a chain like:

for_each_site("i").at_least_occupancy("i", 2).where(site("i") != 3)

is interpreted as one combined guard:

(site(i) >= 2) AND (site(i) != 3)

This matters for both readability and debugging. Treat each clause as one reusable condition block; the term's final guard is simply the AND of all of them.
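To make the AND semantics concrete outside the DSL, here is a minimal plain-Python model. It assumes a predicate can be modeled as a function from one site's occupancy to a bool. The names `Predicate`, `compose_and`, and `active_rows` are hypothetical illustrations, not part of nkdsl.

```python
from typing import Callable, List, Sequence

# Hypothetical model (not nkdsl): a predicate maps one site occupancy to True/False.
Predicate = Callable[[int], bool]


def compose_and(clauses: Sequence[Predicate]) -> Predicate:
    """Combine clauses into one guard; every clause must hold (logical AND)."""
    def guard(value: int) -> bool:
        return all(clause(value) for clause in clauses)
    return guard


def active_rows(state: Sequence[int], guard: Predicate) -> List[bool]:
    """Evaluate the guard once per iterator row (one row per site)."""
    return [guard(v) for v in state]


# Mirrors the chain .at_least_occupancy("i", 2).where(site("i") != 3)
guard = compose_and([lambda v: v >= 2, lambda v: v != 3])
print(active_rows([0, 1, 2, 3, 1], guard))  # [False, False, True, False, False]
```

Because AND is commutative, clause order changes only how the combined expression is written, not which rows are active.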

import jax.numpy as jnp
import netket as nk
import nkdsl

hi = nk.hilbert.Fock(n_max=3, N=5)
print("Hilbert size:", hi.size)
print("Built-in predicate clauses:", sorted(nkdsl.available_predicate_clause_names()))
Hilbert size: 5
Built-in predicate clauses: ['where']

Example 1: a reusable occupancy threshold

This clause exposes a domain concept directly in the DSL: “site occupancy must be at least cutoff”. We implement it once, then use it in any operator without rewriting expression code each time.

class AtLeastOccupancy(nkdsl.AbstractPredicateClause):
    clause_name = "at_least_occupancy"

    def build_predicate(self, ctx, label: str = "i", cutoff: int = 1):
        return ctx.site(label).value >= int(cutoff)


nkdsl.register_predicate_clause(AtLeastOccupancy, replace=True)
print("Registered:", "at_least_occupancy" in nkdsl.available_predicate_clause_names())
Registered: True
op_threshold = (
    nkdsl.SymbolicDiscreteJaxOperator(hi, "threshold")
    .for_each_site("i")
    .at_least_occupancy("i", cutoff=2)
    .emit(nkdsl.identity(), matrix_element=nkdsl.site("i").value)
    .build()
    .compile()
)

x = jnp.asarray([0, 1, 2, 3, 1], dtype=jnp.int32)
_xp, mels = op_threshold.get_conn_padded(x)
print("mels:", mels)
mels: [5. 0. 0. 0. 0.]

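Interpretation: only iterator rows where the predicate is true contribute. Here sites 2 and 3 pass the cutoff and contribute their occupancies 2 and 3; because both emissions are identity updates, they appear together as the single diagonal element 5 in the first slot. Inactive rows still occupy slots in the padded output shape, but they contribute zero amplitude.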

This is why predicates are a clean way to express physical selection rules while preserving static shapes.
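The printed result can be reproduced with a small plain-Python model of the padded layout. The model assumes one padded slot per iterator row, zero amplitude for rows whose guard is false, and that identity updates from all active rows land on the same diagonal entry (consistent with 5 = 2 + 3 above). `padded_amplitudes` is a hypothetical helper, not nkdsl API.

```python
from typing import Callable, List, Sequence


def padded_amplitudes(
    state: Sequence[int],
    guard: Callable[[int], bool],
    amplitude: Callable[[int], float],
) -> List[float]:
    """One slot per iterator row; inactive rows keep their slot but carry 0.0."""
    return [float(amplitude(v)) if guard(v) else 0.0 for v in state]


x = [0, 1, 2, 3, 1]
rows = padded_amplitudes(x, guard=lambda v: v >= 2, amplitude=lambda v: v)
print(rows)  # [0.0, 0.0, 2.0, 3.0, 0.0]

# Every emission is an identity update, so all active rows hit the diagonal.
diagonal = sum(rows)
print(diagonal)  # 5.0
```

The length of `rows` never depends on the state's values, which is the static-shape property; only the amplitudes change.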

Example 2: bounded-value predicate with argument validation

When a clause has user parameters, validate them early. Here we reject lower > upper immediately so you get a direct error message near the call site.
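Stripped of the DSL, the validate-then-build pattern is just a factory that checks its arguments before returning the condition. The plain-Python sketch below uses a hypothetical `make_band` helper; in nkdsl the same check lives inside `build_predicate`, as in the class that follows.

```python
from typing import Callable


def make_band(lower: int, upper: int) -> Callable[[int], bool]:
    """Return a 'lower <= value <= upper' test, rejecting an empty band up front."""
    lo, hi = int(lower), int(upper)
    if lo > hi:
        # Fail at construction time, close to the caller's mistake.
        raise ValueError("lower must be <= upper.")
    return lambda value: lo <= value <= hi


in_band = make_band(1, 2)
print([in_band(v) for v in [0, 1, 2, 3]])  # [False, True, True, False]

try:
    make_band(3, 1)
except ValueError as err:
    print("rejected:", err)  # rejected: lower must be <= upper.
```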

class ValueInBand(nkdsl.AbstractPredicateClause):
    clause_name = "value_in_band"

    def build_predicate(self, ctx, label: str = "i", *, lower: int, upper: int):
        lo = int(lower)
        hi_ = int(upper)
        if lo > hi_:
            raise ValueError("lower must be <= upper.")
        v = ctx.site(label).value
        return (v >= lo) & (v <= hi_)


nkdsl.register_predicate_clause(ValueInBand, replace=True)
op_banded_sym = (
    nkdsl.SymbolicDiscreteJaxOperator(hi, "banded")
    .for_each_site("i")
    .value_in_band("i", lower=1, upper=2)
    .where(nkdsl.site("i") != 2)
    .emit(nkdsl.identity(), matrix_element=1.0)
    .build()
)
op_banded = op_banded_sym.compile()

for state in [
    jnp.asarray([0, 1, 2, 1, 3], dtype=jnp.int32),
    jnp.asarray([2, 2, 1, 0, 1], dtype=jnp.int32),
]:
    _xp2, mels2 = op_banded.get_conn_padded(state)
    print("x=", state.tolist(), "-> mels=", mels2)
x= [0, 1, 2, 1, 3] -> mels= [2. 0. 0. 0. 0.]
x= [2, 2, 1, 0, 1] -> mels= [2. 0. 0. 0. 0.]
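A quick hand check of the two outputs above, written as a plain-Python sketch (not nkdsl code): the composed guard is `1 <= v <= 2` AND `v != 2`, so only sites with occupancy exactly 1 are active, and each active row contributes `1.0` to the diagonal. The helper name `banded_diagonal` is hypothetical.

```python
from typing import Sequence


def banded_diagonal(state: Sequence[int]) -> float:
    """Sum 1.0 over rows where (1 <= v <= 2) and (v != 2)."""
    return sum(1.0 for v in state if (1 <= v <= 2) and (v != 2))


for state in ([0, 1, 2, 1, 3], [2, 2, 1, 0, 1]):
    print(state, "->", banded_diagonal(state))  # 2.0 for both states
```

Both states contain exactly two sites with occupancy 1, which matches the leading 2 in each printed `mels` array.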
print(op_banded_sym.to_ir())
symbolic.operator @"banded" [dtype=float64, hermitian=false, hilbert_size=5] {
  ; 1 term(s)

  term #0 "0" [kbody, n_iter=5, max_conn_size=5] {
    iterate: for (i,) in [(0,), (1,), (2,), ... +2 more]
    where:   (((x[i] >= 1) && (x[i] <= 2)) && (x[i] != 2))
    emit #0:
      update:    identity
      amplitude: 1
  }

}

Reading tip: in the IR, check the where: line of each term to confirm the final composed predicate. If the printed predicate does not match your intent, the bug is usually in clause construction (or in call ordering), not in lowering.
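The nesting in that where: line is consistent with left-to-right AND composition: each new condition wraps everything accumulated so far. The plain-Python sketch below rebuilds the string from the three individual conditions. The `render_where` helper is hypothetical and only imitates the printed format; it does not reuse nkdsl's IR printer.

```python
from functools import reduce
from typing import Sequence


def render_where(conditions: Sequence[str]) -> str:
    """Left-fold conditions with &&, wrapping each partial result in parentheses."""
    atoms = [f"({c})" for c in conditions]
    return reduce(lambda acc, nxt: f"({acc} && {nxt})", atoms)


expected = "(((x[i] >= 1) && (x[i] <= 2)) && (x[i] != 2))"
rendered = render_where(["x[i] >= 1", "x[i] <= 2", "x[i] != 2"])
print(rendered)
print(rendered == expected)  # True
```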

Design checklist for robust predicate clauses

  • Keep one clause focused on one concept (threshold, band, parity, membership, etc.).

  • Use explicit names and argument defaults that read well in fluent chains.

  • Validate parameters inside build_predicate.

  • Return only boolean-coercible expression values.

  • Print IR during development to verify composition behavior.
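As a plain-Python illustration of the checklist (one concept per factory, readable keyword arguments, validation before building, boolean results), here are two more single-purpose conditions for parity and membership. The names `make_parity` and `make_membership` are hypothetical; they model the conditions only and do not produce nkdsl expression objects.

```python
from typing import Callable, Iterable


def make_parity(*, even: bool = True) -> Callable[[int], bool]:
    """One concept: occupancy parity. The keyword-only flag reads well at call sites."""
    remainder = 0 if even else 1
    return lambda value: value % 2 == remainder


def make_membership(*, allowed: Iterable[int]) -> Callable[[int], bool]:
    """One concept: occupancy belongs to a fixed set, validated before use."""
    allowed_set = frozenset(int(a) for a in allowed)
    if not allowed_set:
        raise ValueError("allowed must contain at least one value.")
    return lambda value: value in allowed_set


is_even = make_parity(even=True)
in_set = make_membership(allowed=[0, 3])
state = [0, 1, 2, 3, 1]
print([is_even(v) and in_set(v) for v in state])  # [True, False, False, False, False]
```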

If you follow these rules, custom predicates stay predictable and easy to compose and debug.