{ "cells": [ { "cell_type": "markdown", "id": "4c38afd4169fe939", "metadata": {}, "source": [ "# Reading Symbolic IR\n", "\n", "This notebook is a deep walkthrough of how to read the symbolic IR emitted by the DSL.\n", "The objective is not just to print IR, but to turn IR text into a practical debugging tool.\n", "\n", "Because this notebook is executed during docs build, we rely on runtime output from `print(ir)`\n", "rather than embedding a static pre-copied IR dump.\n" ] }, { "cell_type": "markdown", "id": "7f26a503e7d291de", "metadata": {}, "source": [ "## What the IR tells you\n", "\n", "The textual IR is the closest compact representation of what the compiler will lower.\n", "It exposes three things clearly:\n", "\n", "1. iterator domains (`iterate:`)\n", "2. predicate logic (`where:` and branch-local conditions)\n", "3. emitted updates and amplitudes (`emit` blocks)\n", "\n", "If a compiled operator behaves unexpectedly, IR is usually the fastest place to locate the mismatch.\n" ] }, { "cell_type": "code", "id": "804032feca72a3b8", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:11.894053Z", "start_time": "2026-05-02T11:57:11.074446Z" } }, "source": [ "import netket as nk\n", "import nkdsl\n", "\n", "hi = nk.hilbert.Fock(n_max=2, N=4)\n", "\n", "op = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"ir-demo\", hermitian=True)\n", " .for_each_site(\"i\")\n", " .named(\"diagonal_count\")\n", " .emit(nkdsl.identity(), matrix_element=nkdsl.site(\"i\").value)\n", " .for_each_pair(\"i\", \"j\")\n", " .named(\"hopping\")\n", " .where(nkdsl.site(\"i\") > 0)\n", " .emit(nkdsl.shift(\"i\", -1).shift(\"j\", +1), matrix_element=1.0, tag=\"hop\")\n", " .build()\n", ")\n" ], "outputs": [ { "data": { "text/plain": [ "\u001B[1;36m∣NK⟩ Tip: \u001B[0m😍 Love? 😡 Hate? these tips? Let us know at \u001B[4;94mhttps://forms.gle/ej8eDGFRu2mww5mQ9\u001B[0m\n" ], "text/html": [ "
\n",
       "
\n" ] }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } } ], "execution_count": 1 }, { "cell_type": "markdown", "id": "84c0bd38d2f03a3d", "metadata": {}, "source": [ "## 1) Print the IR\n", "\n", "Run the next cell and inspect the raw output first.\n", "Do this before looking at helper analyses so your own reading skill improves.\n" ] }, { "cell_type": "code", "id": "8dccda042765d566", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:11.911160Z", "start_time": "2026-05-02T11:57:11.895376Z" } }, "source": [ "ir = op.to_ir()\n", "print(ir)\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "symbolic.operator @\"ir-demo\" [dtype=float64, hermitian=true, hilbert_size=4] {\n", " ; 2 term(s)\n", "\n", " term #0 \"diagonal_count\" [kbody, n_iter=4, max_conn_size=4] {\n", " iterate: for (i,) in [(0,), (1,), (2,), ... +1 more]\n", " where: true\n", " emit #0:\n", " update: identity\n", " amplitude: x[i]\n", " }\n", "\n", " term #1 \"hopping\" [kbody, n_iter=16, max_conn_size=16] {\n", " iterate: for (i, j) in [(0, 0), (0, 1), (0, 2), ... +13 more]\n", " where: (x[i] > 0)\n", " emit #0 [tag='hop']:\n", " update: x'[i] = (x[i] + -1); x'[j] = (x[j] + 1)\n", " amplitude: 1\n", " }\n", "\n", "}\n" ] } ], "execution_count": 2 }, { "cell_type": "markdown", "id": "af8377a8b5f2a584", "metadata": {}, "source": [ "## 2) General structure of this IR\n", "\n", "Read the printed IR as a hierarchy with three nested levels:\n", "\n", "1. Operator block\n", "This starts at `symbolic.operator ... {` and ends at the final `}`. It declares global metadata such as dtype, Hermitian flag, and Hilbert size.\n", "\n", "2. Term blocks\n", "Each `term #k \"name\" [...] { ... }` block is one logical contribution to the operator. The header tells you iterator type, number of iterator rows (`n_iter`), and a static upper bound for branch count (`max_conn_size`).\n", "\n", "3. 
Emission blocks inside each term\n", "Each `emit #m ...:` block describes one branch: a rewrite (`update`) and its matrix element (`amplitude`). For conditional branches, an additional branch-local `where:` appears in that block.\n", "\n", "Inside each term, read in this order every time:\n", "`iterate` -> `where` -> each `emit` block (`update` + `amplitude`).\n", "\n", "That order mirrors how the DSL is interpreted and lowered, so it is the fastest way to debug.\n" ] }, { "cell_type": "markdown", "id": "4719984f80cf135", "metadata": {}, "source": [ "## 3) Line-by-line walkthrough of the printed output above\n", "\n", "Below is the exact interpretation of each printed line in this notebook run.\n", "\n", "1. `symbolic.operator @\"ir-demo\" [dtype=float64, hermitian=true, hilbert_size=4] {`\n", "Opens the operator block and declares global metadata: symbolic name, scalar dtype, Hermitian declaration, and Hilbert size.\n", "\n", "2. ` ; 2 term(s)`\n", "A summary comment telling you this operator contains exactly two terms.\n", "\n", "3. ` term #0 \"diagonal_count\" [kbody, n_iter=4, max_conn_size=4] {`\n", "Opens term 0. This term iterates over 4 rows and has a static branch bound of 4.\n", "\n", "4. ` iterate: for (i,) in [(0,), (1,), (2,), ... +1 more]`\n", "Iterator domain preview for term 0: one label (`i`) over all four sites.\n", "\n", "5. ` where: true`\n", "Term-level predicate is unconditional; every iterator row is eligible.\n", "\n", "6. ` emit #0:`\n", "Term 0 has one emission branch.\n", "\n", "7. ` update: identity`\n", "This branch does not rewrite the state.\n", "\n", "8. ` amplitude: x[i]`\n", "Matrix element for this branch is the occupancy/value at site label `i`.\n", "\n", "9. ` }`\n", "Closes term 0.\n", "\n", "10. ` term #1 \"hopping\" [kbody, n_iter=16, max_conn_size=16] {`\n", "Opens term 1. This term iterates over 16 ordered `(i, j)` rows for a 4-site system.\n", "\n", "11. ` iterate: for (i, j) in [(0, 0), (0, 1), (0, 2), ... 
+13 more]`\n", "Iterator domain preview for term 1: pair labels `(i, j)` over all ordered pairs.\n", "\n", "12. ` where: (x[i] > 0)`\n", "Term-level activation guard: hopping only applies when source site `i` is occupied/positive.\n", "\n", "13. ` emit #0 [tag='hop']:`\n", "One emission branch for this term, with diagnostic tag `hop`.\n", "\n", "14. ` update: x'[i] = (x[i] + -1); x'[j] = (x[j] + 1)`\n", "Rewrite rule for hopping: remove one unit at `i`, add one unit at `j`.\n", "\n", "15. ` amplitude: 1`\n", "Constant matrix element for the hop branch.\n", "\n", "16. ` }`\n", "Closes term 1.\n", "\n", "17. `}`\n", "Closes the full operator block.\n", "\n", "Practical reading rule: if behavior is wrong, check lines in this order first: `iterate`, then `where`, then each `emit` (`update`/`amplitude`).\n" ] }, { "cell_type": "markdown", "id": "3263713aaf3a1b44", "metadata": {}, "source": [ "## 4) Structured inspection with `as_dict()`\n", "\n", "Text is ideal for humans, while structured payloads are better for tooling and assertions.\n", "Use `as_dict()` when writing diagnostics, tests, or metadata checks.\n" ] }, { "cell_type": "code", "id": "373eb70f92e82529", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:11.927816Z", "start_time": "2026-05-02T11:57:11.911673Z" } }, "source": [ "payload = ir.as_dict()\n", "print(\"top-level keys:\", sorted(payload.keys()))\n", "print(\"number of terms:\", len(payload[\"terms\"]))\n", "print(\"first term keys:\", sorted(payload[\"terms\"][0].keys()))\n", "print(\"first term name:\", payload[\"terms\"][0][\"name\"])\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "top-level keys: ['dtype_str', 'hilbert_size', 'is_hermitian', 'metadata', 'mode', 'operator_name', 'terms']\n", "number of terms: 2\n", "first term keys: ['amplitude', 'branch_tag', 'emissions', 'iterator', 'max_conn_size_hint', 'metadata', 'name', 'predicate', 'update_program']\n", "first term name: diagonal_count\n" ] } ], 
"execution_count": 3 }, { "cell_type": "markdown", "id": "55df7669ceb18600", "metadata": {}, "source": [ "## 5) Free symbols and static fingerprint\n", "\n", "Two especially useful IR-level diagnostics:\n", "\n", "- `free_symbols`: external symbolic parameters required by this operator\n", "- `static_fingerprint()`: structural identity hash for cache keys and change detection\n" ] }, { "cell_type": "code", "id": "9aadc5614ce63194", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:11.944407Z", "start_time": "2026-05-02T11:57:11.928778Z" } }, "source": [ "op_with_symbol = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"with-symbol\")\n", " .for_each_site(\"i\")\n", " .emit(nkdsl.identity(), matrix_element=nkdsl.symbol(\"J\") * nkdsl.site(\"i\").value)\n", " .build()\n", ")\n", "\n", "ir_sym = op_with_symbol.to_ir()\n", "print(\"free symbols:\", sorted(ir_sym.free_symbols))\n", "print(\"fingerprint:\", ir_sym.static_fingerprint()[:24], \"...\")\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "free symbols: ['J']\n", "fingerprint: d5c4c317a5b3a888d67c7968 ...\n" ] } ], "execution_count": 4 }, { "cell_type": "markdown", "id": "ca26ac80ccb47e2b", "metadata": {}, "source": [ "## 6) Read IR for conditional emissions\n", "\n", "Conditional emissions introduce branch-local predicates.\n", "Printing this IR helps verify that `if / elseif / else` logic was lowered as intended.\n" ] }, { "cell_type": "code", "id": "ab6e0df2d90f911", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:11.962741Z", "start_time": "2026-05-02T11:57:11.946428Z" } }, "source": [ "op_cond = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"cond-ir\")\n", " .for_each_site(\"i\")\n", " .emit_if(nkdsl.site(\"i\") == 0, nkdsl.write(\"i\", 1), matrix_element=10.0, tag=\"if\")\n", " .emit_elseif(nkdsl.site(\"i\") == 1, nkdsl.write(\"i\", 2), matrix_element=20.0, tag=\"elseif\")\n", " .emit_else(nkdsl.write(\"i\", 0), matrix_element=30.0, tag=\"else\")\n", " 
.build()\n", ")\n", "\n", "print(op_cond.to_ir())\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "symbolic.operator @\"cond-ir\" [dtype=float64, hermitian=false, hilbert_size=4] {\n", " ; 1 term(s)\n", "\n", " term #0 \"0\" [kbody, n_iter=4, max_conn_size=12] {\n", " iterate: for (i,) in [(0,), (1,), (2,), ... +1 more]\n", " where: true\n", " emit #0 [tag='if']:\n", " where: (x[i] == 0)\n", " update: x'[i] = 1\n", " amplitude: 10\n", " emit #1 [tag='elseif']:\n", " where: (!(x[i] == 0) && (x[i] == 1))\n", " update: x'[i] = 2\n", " amplitude: 20\n", " emit #2 [tag='else']:\n", " where: (!(x[i] == 0) && !(x[i] == 1))\n", " update: x'[i] = 0\n", " amplitude: 30\n", " }\n", "\n", "}\n" ] } ], "execution_count": 5 }, { "cell_type": "markdown", "id": "1fd1245437c070eb", "metadata": {}, "source": [ "## Closing advice\n", "\n", "Treat IR reading as a first-class engineering skill. In day-to-day work, a fast IR check often\n", "resolves issues before you need deep runtime debugging.\n", "\n", "If behavior looks wrong, ask in this order:\n", "\n", "1. Is the iterator domain correct?\n", "2. Is predicate composition correct?\n", "3. Are emission updates/amplitudes correct?\n", "\n", "That order aligns with the compiler pipeline and usually finds the issue quickly.\n" ] }, { "cell_type": "code", "id": "69923e4aa94c7e72", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:11.966235Z", "start_time": "2026-05-02T11:57:11.963386Z" } }, "source": [], "outputs": [], "execution_count": 5 } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "pygments_lexer": "ipython3" }, "mystnb": { "execution_mode": "cache" } }, "nbformat": 4, "nbformat_minor": 5 }