{
"cells": [
{
"cell_type": "markdown",
"id": "d285bdd3307edf4b",
"metadata": {},
"source": [
"# Custom Predicates\n",
"\n",
"Predicates define **when** a term (or branch) is active over the iterator domain.\n",
"This notebook focuses on writing reusable predicate clauses that feel native in the\n",
"fluent API and produce predictable behavior in compiled operators.\n",
"\n",
"We will implement two clauses, compose them with the built-in `where(...)` clause, and inspect\n",
"how the resulting predicate logic shows up in operator output and in the IR.\n"
]
},
{
"cell_type": "markdown",
"id": "d5251d265bcf08a5",
"metadata": {},
"source": [
"## Predicate semantics in one page\n",
"\n",
"A predicate clause builds a boolean expression and attaches it to the current term.\n",
"When a term has several predicate clauses, they compose with logical AND.\n",
"\n",
"For example, the following chain uses the custom `at_least_occupancy` clause defined in Example 1:\n",
"\n",
"`for_each_site(\"i\").at_least_occupancy(\"i\", 2).where(site(\"i\") != 3)`\n",
"\n",
"It is interpreted as one combined guard:\n",
"\n",
"`(site(i) >= 2) AND (site(i) != 3)`\n",
"\n",
"This rule matters for both readability and debugging. Treat each clause as one reusable condition block,\n",
"and read a chain of clauses as the conjunction of those blocks.\n"
]
},
{
"cell_type": "code",
"id": "ee6e020ee68b85a3",
"metadata": {
"ExecuteTime": {
"end_time": "2026-05-02T11:57:09.774571Z",
"start_time": "2026-05-02T11:57:08.766404Z"
}
},
"source": [
"import jax.numpy as jnp\n",
"import netket as nk\n",
"import nkdsl\n",
"\n",
"# Bosonic Fock space: 5 sites, each with occupation 0..3.\n",
"hi = nk.hilbert.Fock(n_max=3, N=5)\n",
"# hi.size is the number of sites (degrees of freedom), not the Hilbert-space dimension.\n",
"print(\"Hilbert size:\", hi.size)\n",
"print(\"Built-in predicate clauses:\", sorted(nkdsl.available_predicate_clause_names()))\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Hilbert size: 5\n",
"Built-in predicate clauses: ['where']\n"
]
}
],
"execution_count": 1
},
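{
"cell_type": "markdown",
"id": "and-composition-sketch-md",
"metadata": {},
"source": [
"The following plain-Python model makes the AND rule concrete for the combined guard from the chain above.\n",
"It does not use nkdsl. Each clause is modeled as a function of one site value, and a term's guard is the\n",
"conjunction of all its clauses. The helper names (`at_least`, `not_equal`, `combined_guard`) are\n",
"hypothetical and only mirror `at_least_occupancy(\"i\", 2)` and `where(site(\"i\") != 3)`.\n"
]
},
{
"cell_type": "code",
"id": "and-composition-sketch-code",
"metadata": {},
"source": [
"# Hypothetical plain-Python model of predicate composition (not the nkdsl API).\n",
"# Each clause maps one site value to a bool; a term's guard is the AND of its clauses.\n",
"\n",
"\n",
"def at_least(cutoff):\n",
"    # Models at_least_occupancy(\"i\", cutoff): site(i) >= cutoff\n",
"    return lambda v: v >= cutoff\n",
"\n",
"\n",
"def not_equal(value):\n",
"    # Models where(site(\"i\") != value)\n",
"    return lambda v: v != value\n",
"\n",
"\n",
"def combined_guard(clauses):\n",
"    # Multiple clauses compose with logical AND.\n",
"    return lambda v: all(clause(v) for clause in clauses)\n",
"\n",
"\n",
"guard = combined_guard([at_least(2), not_equal(3)])\n",
"\n",
"sites = [0, 1, 2, 3, 1]\n",
"mask = [guard(v) for v in sites]\n",
"print(\"site values:\", sites)\n",
"print(\"active:     \", mask)\n"
],
"outputs": [],
"execution_count": null
},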
{
"cell_type": "markdown",
"id": "50f3f6cff2ded4fc",
"metadata": {},
"source": [
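"## Example 1: a reusable occupancy threshold\n",
"\n",
"This clause exposes a domain concept directly in the DSL: \"site occupancy must be at least `cutoff`\".\n",
"The `clause_name` becomes the fluent method name. The arguments of that call are passed to\n",
"`build_predicate` after `ctx`. We implement the clause once, then reuse it in any operator without\n",
"rewriting the expression code each time.\n"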
]
},
{
"cell_type": "code",
"id": "6acf2bb7984cc2e0",
"metadata": {
"ExecuteTime": {
"end_time": "2026-05-02T11:57:09.811987Z",
"start_time": "2026-05-02T11:57:09.787997Z"
}
},
"source": [
"class AtLeastOccupancy(nkdsl.AbstractPredicateClause):\n",
"    # Name of the fluent method this clause adds to the operator builder.\n",
"    clause_name = \"at_least_occupancy\"\n",
"\n",
"    def build_predicate(self, ctx, label: str = \"i\", cutoff: int = 1):\n",
"        # Returns a symbolic comparison on the site value, not a Python bool.\n",
"        return ctx.site(label).value >= int(cutoff)\n",
"\n",
"\n",
"nkdsl.register_predicate_clause(AtLeastOccupancy, replace=True)\n",
"print(\"Registered:\", \"at_least_occupancy\" in nkdsl.available_predicate_clause_names())\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Registered: True\n"
]
}
],
"execution_count": 2
},
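{
"cell_type": "markdown",
"id": "toy-registry-sketch-md",
"metadata": {},
"source": [
"For intuition only, here is a toy model of the register-then-call pattern used above. A registry maps each\n",
"`clause_name` to its class. The builder resolves unknown method names through that registry, records each\n",
"call, and ANDs together the results of `build_predicate`. Everything here (`REGISTRY`, `register`,\n",
"`ToyContext`, `ToyBuilder`) is hypothetical and makes no claim about how nkdsl is implemented internally.\n",
"The toy context also ignores `label` and holds a single site value.\n"
]
},
{
"cell_type": "code",
"id": "toy-registry-sketch-code",
"metadata": {},
"source": [
"# Hypothetical toy model of clause registration and fluent dispatch.\n",
"# These names are invented for illustration; they are not part of nkdsl.\n",
"\n",
"REGISTRY = {}\n",
"\n",
"\n",
"def register(cls):\n",
"    # Map the clause's fluent method name to its class.\n",
"    REGISTRY[cls.clause_name] = cls\n",
"    return cls\n",
"\n",
"\n",
"class ToyContext:\n",
"    # Stand-in build context holding one concrete site value.\n",
"    def __init__(self, value):\n",
"        self.value = value\n",
"\n",
"\n",
"class ToyBuilder:\n",
"    def __init__(self):\n",
"        self.calls = []  # (clause, args, kwargs) triples, AND-ed together in guard()\n",
"\n",
"    def __getattr__(self, name):\n",
"        # Called only when normal lookup fails: resolve registered clause names.\n",
"        if name not in REGISTRY:\n",
"            raise AttributeError(name)\n",
"        clause = REGISTRY[name]()\n",
"\n",
"        def method(*args, **kwargs):\n",
"            self.calls.append((clause, args, kwargs))\n",
"            return self  # returning the builder keeps the chain fluent\n",
"\n",
"        return method\n",
"\n",
"    def guard(self, value):\n",
"        ctx = ToyContext(value)\n",
"        return all(c.build_predicate(ctx, *a, **kw) for c, a, kw in self.calls)\n",
"\n",
"\n",
"@register\n",
"class ToyAtLeast:\n",
"    clause_name = \"at_least_occupancy\"\n",
"\n",
"    def build_predicate(self, ctx, label=\"i\", cutoff=1):\n",
"        # The toy context ignores `label` and compares its single value.\n",
"        return ctx.value >= cutoff\n",
"\n",
"\n",
"toy = ToyBuilder().at_least_occupancy(\"i\", cutoff=2)\n",
"print([toy.guard(v) for v in [0, 1, 2, 3, 1]])\n"
],
"outputs": [],
"execution_count": null
},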
{
"cell_type": "code",
"id": "7d03cbba540cffee",
"metadata": {
"ExecuteTime": {
"end_time": "2026-05-02T11:57:10.571121Z",
"start_time": "2026-05-02T11:57:09.812469Z"
}
},
"source": [
"op_threshold = (\n",
"    nkdsl.SymbolicDiscreteJaxOperator(hi, \"threshold\")\n",
"    .for_each_site(\"i\")\n",
"    .at_least_occupancy(\"i\", cutoff=2)\n",
"    .emit(nkdsl.identity(), matrix_element=nkdsl.site(\"i\").value)\n",
"    .build()\n",
"    .compile()\n",
")\n",
"\n",
"x = jnp.asarray([0, 1, 2, 3, 1], dtype=jnp.int32)\n",
"_xp, mels = op_threshold.get_conn_padded(x)\n",
"print(\"mels:\", mels)\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"mels: [5. 0. 0. 0. 0.]\n"
]
}
],
"execution_count": 3
},
{
"cell_type": "markdown",
"id": "da1767453de2e802",
"metadata": {},
"source": [
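"Interpretation: only iterator rows where the predicate is true contribute a matrix element.\n",
"Inactive rows contribute zero amplitude. Here the active rows are sites 2 and 3 (occupancies 2 and 3).\n",
"Both emit the identity update, so both connect `x` to itself. The output shows their amplitudes\n",
"combined into a single diagonal entry, `5 = 2 + 3`, followed by zero-valued padding slots.\n",
"\n",
"The padded output keeps the same static shape no matter how many rows are active. This is why\n",
"predicates are a clean way to express physical selection rules while preserving static shapes.\n"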
]
},
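{
"cell_type": "markdown",
"id": "threshold-crosscheck-md",
"metadata": {},
"source": [
"As a cross-check, the same numbers can be recomputed in plain Python. This sketch assumes that identity\n",
"emits from all active rows are summed into one diagonal element, as the single nonzero entry in the output\n",
"suggests. `x_values`, `row_amplitudes`, and `diagonal` are illustrative names, not nkdsl objects.\n"
]
},
{
"cell_type": "code",
"id": "threshold-crosscheck-code",
"metadata": {},
"source": [
"# Plain-Python recomputation of the threshold example (no nkdsl involved).\n",
"x_values = [0, 1, 2, 3, 1]\n",
"cutoff = 2\n",
"\n",
"# Per-row amplitude: the site value where the predicate holds, 0 otherwise.\n",
"row_amplitudes = [v if v >= cutoff else 0 for v in x_values]\n",
"\n",
"# Every row emits the identity update, so all rows connect x back to itself;\n",
"# assuming their amplitudes are summed, the diagonal entry is the row total.\n",
"diagonal = sum(row_amplitudes)\n",
"\n",
"print(\"row amplitudes:\", row_amplitudes)\n",
"print(\"diagonal entry:\", diagonal)\n"
],
"outputs": [],
"execution_count": null
},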
{
"cell_type": "markdown",
"id": "b2e1f996c466483e",
"metadata": {},
"source": [
"## Example 2: bounded-value predicate with argument validation\n",
"\n",
"When a clause has user parameters, validate them early. Here we reject `lower > upper`\n",
"immediately so you get a direct error message near the call site.\n"
]
},
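{
"cell_type": "markdown",
"id": "eager-validation-sketch-md",
"metadata": {},
"source": [
"The same validate-early idea works in plain Python, independent of nkdsl. The hypothetical factory\n",
"`make_band_predicate` checks its arguments once. An empty band therefore fails when the predicate is\n",
"created, not later when states are evaluated.\n"
]
},
{
"cell_type": "code",
"id": "eager-validation-sketch-code",
"metadata": {},
"source": [
"# Hypothetical plain-Python analogue of eager argument validation (not nkdsl code).\n",
"\n",
"\n",
"def make_band_predicate(lower, upper):\n",
"    lo, up = int(lower), int(upper)\n",
"    # Validate once, at construction time, before any state is evaluated.\n",
"    if lo > up:\n",
"        raise ValueError(\"lower must be <= upper.\")\n",
"    return lambda v: lo <= v <= up\n",
"\n",
"\n",
"in_band = make_band_predicate(1, 2)\n",
"print([in_band(v) for v in [0, 1, 2, 3]])\n",
"\n",
"try:\n",
"    make_band_predicate(3, 1)\n",
"except ValueError as err:\n",
"    print(\"rejected:\", err)\n"
],
"outputs": [],
"execution_count": null
},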
{
"cell_type": "code",
"id": "15115855a1dd0d50",
"metadata": {
"ExecuteTime": {
"end_time": "2026-05-02T11:57:10.601759Z",
"start_time": "2026-05-02T11:57:10.597612Z"
}
},
"source": [
"class ValueInBand(nkdsl.AbstractPredicateClause):\n",
"    clause_name = \"value_in_band\"\n",
"\n",
"    def build_predicate(self, ctx, label: str = \"i\", *, lower: int, upper: int):\n",
"        lo = int(lower)\n",
"        hi_ = int(upper)  # trailing underscore avoids shadowing the Hilbert space `hi`\n",
"        # Fail fast on an empty band, before any expression is built.\n",
"        if lo > hi_:\n",
"            raise ValueError(\"lower must be <= upper.\")\n",
"        v = ctx.site(label).value\n",
"        # Combine the two symbolic comparisons with `&`.\n",
"        return (v >= lo) & (v <= hi_)\n",
"\n",
"\n",
"nkdsl.register_predicate_clause(ValueInBand, replace=True)\n"
],
"outputs": [
{
"data": {
"text/plain": [
"__main__.ValueInBand"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"execution_count": 4
},
{
"cell_type": "code",
"id": "4c50a55b7f77789d",
"metadata": {
"ExecuteTime": {
"end_time": "2026-05-02T11:57:11.516021Z",
"start_time": "2026-05-02T11:57:10.602179Z"
}
},
"source": [
"# Keep the uncompiled symbolic operator so its IR can be inspected below.\n",
"op_banded_sym = (\n",
"    nkdsl.SymbolicDiscreteJaxOperator(hi, \"banded\")\n",
"    .for_each_site(\"i\")\n",
"    .value_in_band(\"i\", lower=1, upper=2)\n",
"    .where(nkdsl.site(\"i\") != 2)\n",
"    .emit(nkdsl.identity(), matrix_element=1.0)\n",
"    .build()\n",
")\n",
"op_banded = op_banded_sym.compile()\n",
"\n",
"# Net guard: 1 <= x[i] <= 2 and x[i] != 2, so only sites with occupancy 1 are active.\n",
"for state in [\n",
"    jnp.asarray([0, 1, 2, 1, 3], dtype=jnp.int32),\n",
"    jnp.asarray([2, 2, 1, 0, 1], dtype=jnp.int32),\n",
"]:\n",
"    _xp2, mels2 = op_banded.get_conn_padded(state)\n",
"    print(\"x=\", state.tolist(), \"-> mels=\", mels2)\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x= [0, 1, 2, 1, 3] -> mels= [2. 0. 0. 0. 0.]\n",
"x= [2, 2, 1, 0, 1] -> mels= [2. 0. 0. 0. 0.]\n"
]
}
],
"execution_count": 5
},
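{
"cell_type": "markdown",
"id": "banded-crosscheck-md",
"metadata": {},
"source": [
"Here is a plain-Python cross-check of both outputs. The net guard `1 <= x[i] <= 2` and `x[i] != 2`\n",
"reduces to `x[i] == 1`. Each such site emits the identity with amplitude `1.0`, and, as in Example 1, the\n",
"contributions appear summed on the diagonal. The helpers `active` and `diagonal_entry` are illustrative,\n",
"not part of nkdsl.\n"
]
},
{
"cell_type": "code",
"id": "banded-crosscheck-code",
"metadata": {},
"source": [
"# Plain-Python cross-check of the banded example (no nkdsl involved).\n",
"\n",
"\n",
"def active(v, lower=1, upper=2, excluded=2):\n",
"    # value_in_band(lower, upper) AND where(site != excluded), on concrete ints,\n",
"    # so Python's `and` is fine here.\n",
"    return lower <= v <= upper and v != excluded\n",
"\n",
"\n",
"def diagonal_entry(values, amplitude=1.0):\n",
"    # Each active row adds `amplitude` to the diagonal (identity update).\n",
"    return amplitude * sum(1 for v in values if active(v))\n",
"\n",
"\n",
"for values in ([0, 1, 2, 1, 3], [2, 2, 1, 0, 1]):\n",
"    active_sites = [i for i, v in enumerate(values) if active(v)]\n",
"    print(values, \"active sites:\", active_sites, \"diagonal:\", diagonal_entry(values))\n"
],
"outputs": [],
"execution_count": null
},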
{
"cell_type": "code",
"id": "ce3f024c2cbcc779",
"metadata": {
"ExecuteTime": {
"end_time": "2026-05-02T11:57:11.525841Z",
"start_time": "2026-05-02T11:57:11.520449Z"
}
},
"source": [
"print(op_banded_sym.to_ir())\n"
],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"symbolic.operator @\"banded\" [dtype=float64, hermitian=false, hilbert_size=5] {\n",
" ; 1 term(s)\n",
"\n",
" term #0 \"0\" [kbody, n_iter=5, max_conn_size=5] {\n",
" iterate: for (i,) in [(0,), (1,), (2,), ... +2 more]\n",
" where: (((x[i] >= 1) && (x[i] <= 2)) && (x[i] != 2))\n",
" emit #0:\n",
" update: identity\n",
" amplitude: 1\n",
" }\n",
"\n",
"}\n"
]
}
],
"execution_count": 6
},
{
"cell_type": "markdown",
"id": "f6e1c5f24ffb3741",
"metadata": {},
"source": [
"Reading tip: in the IR, look at `where:` for each term to confirm the final composed predicate.\n",
"If the textual predicate does not match your intent, the bug is usually in clause construction\n",
"(or in call ordering), not in lowering.\n"
]
},
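{
"cell_type": "markdown",
"id": "toy-expr-ir-sketch-md",
"metadata": {},
"source": [
"The following toy re-creation of the `where:` line above uses a hypothetical `Expr` class rather than\n",
"nkdsl's own expression types. Overloaded comparison operators return new sub-expressions instead of\n",
"booleans. Each additional clause is AND-ed onto the guard built so far, which reproduces the left-nested\n",
"parentheses in the IR. The toy also shows why `ValueInBand` combines its comparisons with `&`. Python's\n",
"`and` cannot be overloaded (it calls `bool()` on its left operand), so it cannot build a symbolic conjunction.\n"
]
},
{
"cell_type": "code",
"id": "toy-expr-ir-sketch-code",
"metadata": {},
"source": [
"# Hypothetical toy expression tree (not nkdsl's classes). Comparisons and `&`\n",
"# return new Expr nodes, so clauses build a symbolic guard instead of a bool.\n",
"\n",
"\n",
"class Expr:\n",
"    def __init__(self, text):\n",
"        self.text = text\n",
"\n",
"    def _binary(self, op, other):\n",
"        # Render the right operand as a sub-expression or a literal.\n",
"        rhs = other.text if isinstance(other, Expr) else repr(other)\n",
"        return Expr(f\"({self.text} {op} {rhs})\")\n",
"\n",
"    def __ge__(self, other):\n",
"        return self._binary(\">=\", other)\n",
"\n",
"    def __le__(self, other):\n",
"        return self._binary(\"<=\", other)\n",
"\n",
"    def __ne__(self, other):\n",
"        return self._binary(\"!=\", other)\n",
"\n",
"    def __and__(self, other):\n",
"        return self._binary(\"&&\", other)\n",
"\n",
"\n",
"v = Expr(\"x[i]\")\n",
"band = (v >= 1) & (v <= 2)  # what a value_in_band-style clause returns\n",
"guard = band & (v != 2)  # a later where(...) clause is AND-ed onto it\n",
"print(guard.text)\n"
],
"outputs": [],
"execution_count": null
},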
{
"cell_type": "markdown",
"id": "9902d5837e72aae9",
"metadata": {},
"source": [
"## Design checklist for robust predicate clauses\n",
"\n",
"- Keep one clause focused on one concept (threshold, band, parity, membership, etc.).\n",
"- Use explicit names and argument defaults that read well in fluent chains.\n",
"- Validate parameters inside `build_predicate`.\n",
"- Return only boolean-coercible expression values, and combine sub-conditions with `&`, as `ValueInBand` does.\n",
"- Print the IR during development to verify composition behavior.\n",
"\n",
"If you follow these rules, custom predicates stay predictable, easy to reuse, and easy to debug.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"pygments_lexer": "ipython3"
},
"mystnb": {
"execution_mode": "cache"
}
},
"nbformat": 4,
"nbformat_minor": 5
}