{ "cells": [ { "cell_type": "markdown", "id": "1415e8ba7fb81f16", "metadata": {}, "source": [ "# Conditional Emissions (`if / elseif / else`)\n", "\n", "This notebook explains piecewise emission logic in the DSL.\n", "Conditional emissions let one term choose among several branch updates while keeping the\n", "output shape static (one padded slot per branch), which JAX needs for jit-compatible lowering.\n", "\n", "We cover semantics, execution behavior, IR inspection, and common mistakes.\n" ] }, { "cell_type": "markdown", "id": "4560d95b5ca9e6f9", "metadata": {}, "source": [ "## Why conditional emissions exist\n", "\n", "Many physical rules are naturally piecewise:\n", "\n", "- if local occupancy is zero, do one transition\n", "- else if occupancy is one, do another\n", "- else apply a fallback\n", "\n", "You can also encode this with several terms, each guarded by its own predicate, but then you must\n", "keep those predicates mutually exclusive by hand. Conditional emissions keep the logic in one term,\n", "derive the exclusivity guards automatically, and keep branch accounting deterministic.\n" ] }, { "cell_type": "code", "id": "30910a32a197b0d1", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:03.192986Z", "start_time": "2026-05-02T11:57:01.962642Z" } }, "source": [ "import jax.numpy as jnp\n", "import netket as nk\n", "import nkdsl\n", "\n", "hi = nk.hilbert.Fock(n_max=4, N=1)\n", "print(\"Hilbert size:\", hi.size)\n" ], "outputs": [ { "data": { "text/plain": [ "\u001B[1;36m∣NK⟩ Tip: \u001B[0mUse \u001B[1;35mdriver.run\u001B[0m\u001B[1m(\u001B[0m\u001B[33m...\u001B[0m, \u001B[33mtimeit\u001B[0m=\u001B[3;92mTrue\u001B[0m\u001B[1m)\u001B[0m to know where your dominant cost is.\n" ], "text/html": [ "
∣NK⟩ Tip: Use driver.run(..., timeit=True) to know where your dominant cost is.\n", "\n" ] }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } }, { "name": "stdout", "output_type": "stream", "text": [ "Hilbert size: 1\n" ] } ], "execution_count": 1 }, { "cell_type": "markdown", "id": "cb7e1764354f33e3", "metadata": {}, "source": [ "## 1) Build a canonical `if / elseif / else` chain\n", "\n", "Read this chain like ordinary control flow, but remember that the compiler still allocates\n", "one static output slot per branch (`max_conn_size=3` in the IR below). At runtime, the predicates\n", "decide which branch contributes a nonzero matrix element; the remaining slots are zero padding.\n" ] }, { "cell_type": "code", "id": "42fda4ff3f994d6a", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:03.786890Z", "start_time": "2026-05-02T11:57:03.194095Z" } }, "source": [ "op_sym = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"piecewise\")\n", " .for_each_site(\"i\")\n", " .emit_if(nkdsl.site(\"i\") == 0, nkdsl.write(\"i\", 1), matrix_element=10.0, tag=\"if\")\n", " .emit_elseif(nkdsl.site(\"i\") == 1, nkdsl.write(\"i\", 2), matrix_element=20.0, tag=\"elseif\")\n", " .emit_else(nkdsl.write(\"i\", 3), matrix_element=30.0, tag=\"else\")\n", " .build()\n", ")\n", "op = op_sym.compile()\n" ], "outputs": [], "execution_count": 2 }, { "cell_type": "code", "id": "62744c8fb8f70d2c", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:03.820748Z", "start_time": "2026-05-02T11:57:03.816230Z" } }, "source": [ "print(op_sym.to_ir())\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "symbolic.operator @\"piecewise\" [dtype=float64, hermitian=false, hilbert_size=1] {\n", " ; 1 term(s)\n", "\n", " term #0 \"0\" [kbody, n_iter=1, max_conn_size=3] {\n", " iterate: for (i,) in [(0,)]\n", " where: true\n", " emit #0 [tag='if']:\n", " where: (x[i] == 0)\n", " update: x'[i] = 1\n", " amplitude: 10\n", " emit #1 [tag='elseif']:\n", " where: (!(x[i] == 0) && (x[i] == 1))\n", " update: x'[i] = 2\n", " amplitude: 20\n", " emit #2 
[tag='else']:\n", " where: (!(x[i] == 0) && !(x[i] == 1))\n", " update: x'[i] = 3\n", " amplitude: 30\n", " }\n", "\n", "}\n" ] } ], "execution_count": 3 }, { "cell_type": "markdown", "id": "9f330ff69bf32a6f", "metadata": {}, "source": [ "In the IR, inspect each `emit` block and check that its `where` predicate matches the intended chain:\n", "\n", "- the first branch uses exactly the `if` predicate\n", "- each `elseif` is guarded by \"previous branches were false\" AND its own predicate\n", "- `else` is guarded by the negation of all prior branch predicates\n", "\n", "This predicate rewriting makes the branches mutually exclusive: because the chain ends in `else`,\n", "exactly one `emit` predicate is true for every input here.\n" ] }, { "cell_type": "markdown", "id": "39e7625753c3dce1", "metadata": {}, "source": [ "## 2) Evaluate behavior across representative inputs\n", "\n", "Because this Hilbert space has a single site (`N=1`), each source configuration is one occupancy value.\n", "We test several occupancies and observe which branch becomes active.\n" ] }, { "cell_type": "code", "id": "c0928edcdd1f7cd7", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:03.916634Z", "start_time": "2026-05-02T11:57:03.821318Z" } }, "source": [ "for n in [0, 1, 2, 3]:\n", " x = jnp.asarray([n], dtype=jnp.int32)\n", " xp, mels = op.get_conn_padded(x)\n", " print(f\"x={n} -> xp={xp.reshape(-1)}, mels={mels}\")\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x=0 -> xp=[1 0 0], mels=[10. 0. 0.]\n", "x=1 -> xp=[2 0 0], mels=[20. 0. 0.]\n", "x=2 -> xp=[3 0 0], mels=[30. 0. 0.]\n", "x=3 -> xp=[3 0 0], mels=[30. 0. 0.]\n" ] } ], "execution_count": 4 }, { "cell_type": "markdown", "id": "b517ba78c5423668", "metadata": {}, "source": [ "Interpretation strategy:\n", "\n", "1. identify the nonzero matrix-element slot(s)\n", "2. match each nonzero entry to a branch by its amplitude or rewritten state (in the output above, the active entry always lands in the first slot, so slot position alone does not identify the branch)\n", "3. 
verify rewritten state and amplitude together\n", "\n", "If more than one slot is nonzero for a chain that should be exclusive, the branch predicates are likely wrong.\n" ] }, { "cell_type": "markdown", "id": "2e5bf4e12a84ed0c", "metadata": {}, "source": [ "## 3) Multiple `elseif` branches\n", "\n", "A chain can contain more than one `emit_elseif(...)`. This is useful when different local occupancy\n", "values map to different transition rules or coefficients. With four branches the padded output has\n", "four slots, and occupancies `3` and `4` both fall through to the `else` branch.\n" ] }, { "cell_type": "code", "id": "cf09212cbb583a5b", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:04.413626Z", "start_time": "2026-05-02T11:57:03.917034Z" } }, "source": [ "op_multi = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"piecewise-4\")\n", " .for_each_site(\"i\")\n", " .emit_if(nkdsl.site(\"i\") == 0, nkdsl.write(\"i\", 1), matrix_element=1.0)\n", " .emit_elseif(nkdsl.site(\"i\") == 1, nkdsl.write(\"i\", 2), matrix_element=2.0)\n", " .emit_elseif(nkdsl.site(\"i\") == 2, nkdsl.write(\"i\", 3), matrix_element=3.0)\n", " .emit_else(nkdsl.write(\"i\", 4), matrix_element=4.0)\n", " .build()\n", " .compile()\n", ")\n", "\n", "for n in [0, 1, 2, 3, 4]:\n", " _xp2, mels2 = op_multi.get_conn_padded(jnp.asarray([n], dtype=jnp.int32))\n", " print(f\"n={n} -> mels={mels2}\")\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "n=0 -> mels=[1. 0. 0. 0.]\n", "n=1 -> mels=[2. 0. 0. 0.]\n", "n=2 -> mels=[3. 0. 0. 0.]\n", "n=3 -> mels=[4. 0. 0. 0.]\n", "n=4 -> mels=[4. 0. 0. 
0.]\n" ] } ], "execution_count": 5 }, { "cell_type": "markdown", "id": "3c2f7370540814cb", "metadata": {}, "source": [ "## 4) Rules and error cases\n", "\n", "Conditional chaining enforces structural rules so that branch semantics are never ambiguous.\n", "`emit_elseif(...)` and `emit_else(...)` are only valid while an `emit_if(...)` chain is open. As the\n", "error message below states, `emit_elseif(...)` must directly follow `emit_if(...)` or another\n", "`emit_elseif(...)`, with no term modifiers in between.\n", "\n", "The next cell calls `emit_elseif(...)` without a preceding `emit_if(...)`, so you can recognize the error.\n" ] }, { "cell_type": "code", "id": "c6fe3881da647c7b", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:04.417002Z", "start_time": "2026-05-02T11:57:04.414034Z" } }, "source": [ "try:\n", " (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"bad-chain\")\n", " .for_each_site(\"i\")\n", " .emit_elseif(nkdsl.site(\"i\") == 0, nkdsl.identity(), matrix_element=1.0)\n", " )\n", "except ValueError as exc:\n", " print(type(exc).__name__, \"->\", exc)\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "ValueError -> .emit_elseif() must follow .emit_if(...) or .emit_elseif(...) 
without intervening term modifiers.\n" ] } ], "execution_count": 6 }, { "cell_type": "markdown", "id": "182e066439557321", "metadata": {}, "source": [ "## 5) Combining term-level and branch-level conditions\n", "\n", "A term-level `where(...)` acts as a coarse gate, and branch predicates refine behavior inside that gate.\n", "When the gate is false, no branch fires, not even `else`: below, `n=0` fails `site(\"i\") >= 1`, so both\n", "slots are zero. This is often the cleanest way to combine hard constraints with piecewise transitions.\n" ] }, { "cell_type": "code", "id": "f8fd46799fe1fd2c", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:04.682505Z", "start_time": "2026-05-02T11:57:04.417302Z" } }, "source": [ "op_guarded = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"guarded-piecewise\")\n", " .for_each_site(\"i\")\n", " .where(nkdsl.site(\"i\") >= 1)\n", " .emit_if(nkdsl.site(\"i\") == 1, nkdsl.write(\"i\", 2), matrix_element=11.0)\n", " .emit_else(nkdsl.write(\"i\", 3), matrix_element=33.0)\n", " .build()\n", " .compile()\n", ")\n", "\n", "for n in [0, 1, 2]:\n", " _xp3, mels3 = op_guarded.get_conn_padded(jnp.asarray([n], dtype=jnp.int32))\n", " print(f\"guarded n={n} -> mels={mels3}\")\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "guarded n=0 -> mels=[0. 0.]\n", "guarded n=1 -> mels=[11. 0.]\n", "guarded n=2 -> mels=[33. 
0.]\n" ] } ], "execution_count": 7 }, { "cell_type": "markdown", "id": "e3d89ca7500bba5a", "metadata": {}, "source": [ "## Practical recommendations\n", "\n", "- Order branches deliberately (earlier predicates take precedence) and document the order with term names and `tag=...` labels.\n", "- Probe a few representative inputs with `get_conn_padded` to confirm that at most one slot is nonzero.\n", "- Inspect the `to_ir()` output to confirm the rewritten branch predicates.\n", "- Prefer one clear conditional term over several partially overlapping terms when logic is piecewise.\n", "\n", "This keeps both the user-facing DSL code and the compiler-facing IR easy to reason about.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "pygments_lexer": "ipython3" }, "mystnb": { "execution_mode": "cache" } }, "nbformat": 4, "nbformat_minor": 5 }