{ "cells": [ { "cell_type": "markdown", "id": "640cb2bdce67415e", "metadata": {}, "source": [ "# Custom Iterators\n", "\n", "This notebook is about extending the iterator layer of the DSL, not just using it.\n", "A custom iterator clause is the piece that decides **which site tuples are visited**\n", "before predicates and emissions run. If you can control iteration, you can encode\n", "many model families without touching compiler internals.\n", "\n", "The goal here is to build an intuition for the contract, then implement two realistic\n", "iterator clauses from scratch and use them in compiled operators.\n" ] }, { "cell_type": "markdown", "id": "6a72f5bbe15512b9", "metadata": {}, "source": [ "## Before we start: mental model and contract\n", "\n", "When you call an iterator method (for example `for_each_site`), the builder opens a term.\n", "Every later predicate and emission in that term is evaluated over the rows emitted by that iterator.\n", "\n", "For custom iterator clauses, the core contract is:\n", "\n", "1. subclass `nkdsl.AbstractIteratorClause`\n", "2. implement `build_iterator(self, hilbert, *args, **kwargs)`\n", "3. 
return data coercible to a `KBodyIteratorSpec`\n", "\n", "In practice, the easiest return form is:\n", "\n", "- `labels`: a tuple of string labels, such as `(\"i\",)` or `(\"i\", \"j\")`\n", "- `rows`: a tuple of integer tuples, one tuple per iterator row\n", "\n", "The DSL enforces arity consistency (`len(row) == len(labels)`) and expects at least one row.\n" ] }, { "cell_type": "code", "id": "c61c1066e8e50cd9", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:04.600552Z", "start_time": "2026-05-02T11:57:03.791321Z" } }, "source": [ "import jax.numpy as jnp\n", "import netket as nk\n", "import nkdsl\n", "\n", "hi = nk.hilbert.Fock(n_max=3, N=6)\n", "print(\"Hilbert size:\", hi.size)\n", "print(\"Built-in iterator clauses (sample):\", sorted(nkdsl.available_iterator_clause_names())[:8])\n" ], "outputs": [ { "data": { "text/plain": [ "\u001B[1;36m∣NK⟩ Tip: \u001B[0mYou must cite NetKet according to our policy. Use \u001B[1;35mnk.cite\u001B[0m\u001B[1m(\u001B[0m\u001B[1m)\u001B[0m to find out how.\n" ], "text/html": [ "
∣NK⟩ Tip: You must cite NetKet according to our policy. Use nk.cite() to find out how.\n",
       "
\n" ] }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } }, { "name": "stdout", "output_type": "stream", "text": [ "Hilbert size: 6\n", "Built-in iterator clauses (sample): ['for_each', 'for_each_distinct_pair', 'for_each_pair', 'for_each_plaquette', 'for_each_site', 'for_each_triplet', 'globally']\n" ] } ], "execution_count": 1 }, { "cell_type": "markdown", "id": "b60a53d52bdfd395", "metadata": {}, "source": [ "## Example 1: iterate only over even sites\n", "\n", "This is a simple but very common pattern: select a deterministic subset of sites.\n", "In larger projects, this can represent a sublattice, a gauge partition, or a domain decomposition.\n", "\n", "Notice two design choices in the implementation below:\n", "\n", "- The clause validates that the generated row list is not empty.\n", "- It takes an optional label argument so you can compose it naturally with existing DSL style.\n" ] }, { "cell_type": "code", "id": "c4843e4f3762cc92", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:04.658481Z", "start_time": "2026-05-02T11:57:04.610165Z" } }, "source": [ "class ForEachEvenSite(nkdsl.AbstractIteratorClause):\n", " clause_name = \"for_each_even_site\"\n", "\n", " def build_iterator(self, hilbert, label: str = \"i\"):\n", " n = int(hilbert.size)\n", " rows = tuple((k,) for k in range(n) if k % 2 == 0)\n", " if not rows:\n", " raise ValueError(\"No even sites are available for this Hilbert space.\")\n", " return (str(label),), rows\n", "\n", "\n", "nkdsl.register_iterator_clause(ForEachEvenSite, replace=True)\n", "print(\"Registered:\", \"for_each_even_site\" in nkdsl.available_iterator_clause_names())\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Registered: True\n" ] } ], "execution_count": 2 }, { "cell_type": "markdown", "id": "ddaced6e2889e0c8", "metadata": {}, "source": [ "### Why this return value works\n", "\n", "`return (\"i\",), ((0,), (2,), (4,), ...)` is accepted because 
iterator coercion supports\n", "`(labels, over)` as a first-class format.\n", "\n", "At compile time, those rows become the static iteration domain for the term.\n", "That means your custom iterator directly controls both term cardinality and the static\n", "padding bound (`max_conn_size`) for the generated operator.\n" ] }, { "cell_type": "code", "id": "7eea30c2bc3cbeaf", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:05.890082Z", "start_time": "2026-05-02T11:57:04.661588Z" } }, "source": [ "op_even = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"diag-even\")\n", " .for_each_even_site(\"i\")\n", " .emit(nkdsl.identity(), matrix_element=nkdsl.site(\"i\").value)\n", " .build()\n", " .compile()\n", ")\n", "\n", "x = jnp.asarray([0, 2, 1, 3, 2, 0], dtype=jnp.int32)\n", "xp, mels = op_even.get_conn_padded(x)\n", "print(\"xp shape:\", xp.shape)\n", "print(\"mels:\", mels)\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "xp shape: (3, 6)\n", "mels: [3. 0. 0.]\n" ] } ], "execution_count": 3 }, { "cell_type": "markdown", "id": "15ecafe0312e6697", "metadata": {}, "source": [ "The padded output has three slots, one per iterator row (sites 0, 2 and 4), which is why `xp` has shape `(3, 6)`.\n", "The single nonzero entry, 3, equals `x[0] + x[2] + x[4] = 0 + 1 + 2`: every emission is the identity, so the three\n", "diagonal contributions appear to be merged into one connected entry.\n", "Since this iterator visits only even indices, the odd-site values (2, 3 and 0) never contribute to this term.\n", "\n", "That is an important distinction:\n", "\n", "- a predicate filters rows **after** iteration\n", "- a custom iterator changes the domain **before** predicates even run\n" ] }, { "cell_type": "markdown", "id": "4e5115607944543a", "metadata": {}, "source": [ "## Example 2: edge-driven iteration with validation\n", "\n", "For graph-like models, you often want iteration over a fixed edge list. 
A custom iterator\n", "is the right layer for that because edge structure is a domain definition, not a boolean filter.\n", "\n", "Below we add defensive checks:\n", "\n", "- at least one edge must be present\n", "- each endpoint index must be within `[0, hilbert.size)`\n", "\n", "This keeps errors close to input and makes broken graphs fail fast.\n" ] }, { "cell_type": "code", "id": "174128a3f08674e3", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:05.917287Z", "start_time": "2026-05-02T11:57:05.913903Z" } }, "source": [ "class ForEachEdge(nkdsl.AbstractIteratorClause):\n", " clause_name = \"for_each_edge\"\n", "\n", " def build_iterator(self, hilbert, src: str = \"i\", dst: str = \"j\", *, edges):\n", " n = int(hilbert.size)\n", " rows = tuple((int(i), int(j)) for i, j in edges)\n", " if not rows:\n", " raise ValueError(\"edges must contain at least one pair.\")\n", " for i, j in rows:\n", " if i < 0 or j < 0 or i >= n or j >= n:\n", " raise ValueError(f\"edge ({i}, {j}) is out of bounds for hilbert size {n}.\")\n", " return (str(src), str(dst)), rows\n", "\n", "\n", "nkdsl.register_iterator_clause(ForEachEdge, replace=True)\n", "print(\"Registered:\", \"for_each_edge\" in nkdsl.available_iterator_clause_names())\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Registered: True\n" ] } ], "execution_count": 4 }, { "cell_type": "code", "id": "e50830908069b9e4", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:06.858397Z", "start_time": "2026-05-02T11:57:05.917573Z" } }, "source": [ "edges = [(0, 1), (1, 2), (2, 3), (4, 5)]\n", "\n", "op_hop_sym = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"edge-hop\")\n", " .for_each_edge(\"i\", \"j\", edges=edges)\n", " .where((nkdsl.site(\"i\") > 0) & (nkdsl.site(\"j\") < 3))\n", " .emit(nkdsl.shift(\"i\", -1).shift(\"j\", +1), matrix_element=1.0)\n", " .build()\n", ")\n", "op_hop = op_hop_sym.compile()\n", "\n", "x_hop = jnp.asarray([1, 0, 1, 0, 2, 0], dtype=jnp.int32)\n", 
"xp_hop, mels_hop = op_hop.get_conn_padded(x_hop)\n", "print(\"xp_hop shape:\", xp_hop.shape)\n", "print(\"mels_hop:\", mels_hop)\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "xp_hop shape: (4, 6)\n", "mels_hop: [1. 1. 1. 0.]\n" ] } ], "execution_count": 5 }, { "cell_type": "code", "id": "fbb5a0b0f2c33c2c", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:06.865344Z", "start_time": "2026-05-02T11:57:06.862463Z" } }, "source": [ "print(op_hop_sym.to_ir())\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "symbolic.operator @\"edge-hop\" [dtype=float64, hermitian=false, hilbert_size=6] {\n", " ; 1 term(s)\n", "\n", " term #0 \"0\" [kbody, n_iter=4, max_conn_size=4] {\n", " iterate: for (i, j) in [(0, 1), (1, 2), (2, 3), ... +1 more]\n", " where: ((x[i] > 0) && (x[j] < 3))\n", " emit #0:\n", " update: x'[i] = (x[i] + -1); x'[j] = (x[j] + 1)\n", " amplitude: 1\n", " }\n", "\n", "}\n" ] } ], "execution_count": 6 }, { "cell_type": "markdown", "id": "e32054441f8dcabe", "metadata": {}, "source": [ "In the IR, inspect the `iterate:` line first. If the iterator rows are wrong, all downstream\n", "predicate and emission behavior will look wrong too, even if their logic is correct.\n", "\n", "For debugging custom iterators, this order usually saves time:\n", "\n", "1. verify labels and row arity in the clause\n", "2. print IR and confirm the iterator domain\n", "3. only then debug predicate/emission behavior\n" ] }, { "cell_type": "markdown", "id": "97a5173a5ad68a0b", "metadata": {}, "source": [ "## Practical guidelines for production clauses\n", "\n", "Keep each iterator clause focused on one selection rule. If a clause needs many optional flags,\n", "it is usually better to split it into two or three explicit clauses with clearer intent.\n", "\n", "Prefer deterministic, fully static row generation from builder inputs. 
The compiler assumes a\n", "static branch shape; dynamic row counts tied to runtime state are not what iterator clauses are for.\n", "\n", "Finally, add argument validation in the clause itself. It is much easier for you to fix a clear\n", "`ValueError` from `build_iterator` than to diagnose a silent semantic mismatch several layers later.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "pygments_lexer": "ipython3" }, "mystnb": { "execution_mode": "cache" } }, "nbformat": 4, "nbformat_minor": 5 }