{ "cells": [ { "cell_type": "markdown", "id": "f334460775e9e1b4", "metadata": {}, "source": [ "# Extending Emission Clauses\n", "\n", "This notebook shows how to create new fluent emission methods by subclassing\n", "`nkdsl.AbstractEmissionClause`.\n", "\n", "Use this extension point when you want reusable branch templates (for example,\n", "\"emit with this guard and this standard tag/update policy\") so you do not repeat\n", "boilerplate `emit_if(...)` blocks everywhere.\n" ] }, { "cell_type": "markdown", "id": "7c6ecb3390612e0a", "metadata": {}, "source": [ "## Emission clause contract\n", "\n", "A custom emission clause should return `nkdsl.EmissionClauseSpec`.\n", "The key fields are:\n", "\n", "- `mode`: one of `emit`, `emit_if`, `emit_elseif`, `emit_else`\n", "- `predicate`: branch predicate (used by conditional modes)\n", "- `update`: rewrite program (`identity`, `shift`, `write`, chained updates, ...)\n", "- `matrix_element`: amplitude expression for that branch\n", "- `tag`: optional diagnostic tag\n", "\n", "In practice, most custom clauses are thin wrappers that package a repeated policy into one method.\n" ] }, { "cell_type": "code", "id": "5f721267e583a4dc", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:09.954507Z", "start_time": "2026-05-02T11:57:09.037958Z" } }, "source": [ "import jax.numpy as jnp\n", "import netket as nk\n", "import nkdsl\n", "\n", "hi = nk.hilbert.Fock(n_max=3, N=4)\n", "print(\"Hilbert size:\", hi.size)\n" ], "outputs": [ { "data": { "text/plain": [ "\u001B[1;36m∣NK⟩ Tip: \u001B[0mIf \u001B[33mtimeit\u001B[0m=\u001B[3;92mTrue\u001B[0m signals high \\% spent sampling n_discarded, consider lowering it.\n" ], "text/html": [ "
∣NK⟩ Tip: If timeit=True signals high \\% spent sampling n_discarded, consider lowering it.\n", "\n" ] }, "metadata": {}, "output_type": "display_data", "jetTransient": { "display_id": null } }, { "name": "stdout", "output_type": "stream", "text": [ "Hilbert size: 4\n" ] } ], "execution_count": 1 }, { "cell_type": "markdown", "id": "a6b3b48da845e543", "metadata": {}, "source": [ "## Example 1: `emit_when_at_least(...)`\n", "\n", "This clause emits only when a label's occupancy is at least a cutoff. It returns\n", "an `emit_if` spec so it can be chained with built-in `emit_elseif` or `emit_else`.\n" ] }, { "cell_type": "code", "id": "8de8ed193f88d5f", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:09.962329Z", "start_time": "2026-05-02T11:57:09.955843Z" } }, "source": [ "class EmitWhenAtLeast(nkdsl.AbstractEmissionClause):\n", " clause_name = \"emit_when_at_least\"\n", "\n", " def build_emission(self, ctx, label: str = \"i\", cutoff: int = 1):\n", " predicate = ctx.site(label).value >= int(cutoff)\n", " return nkdsl.EmissionClauseSpec(\n", " mode=\"emit_if\",\n", " predicate=predicate,\n", " update=nkdsl.identity(),\n", " matrix_element=ctx.site(label).value,\n", " tag=\"emit-when-at-least\",\n", " )\n", "\n", "\n", "nkdsl.register_emission_clause(EmitWhenAtLeast, replace=True)\n", "print(\"Registered:\", \"emit_when_at_least\" in nkdsl.available_emission_clause_names())\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Registered: True\n" ] } ], "execution_count": 2 }, { "cell_type": "code", "id": "6d150acf10694959", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:11.777276Z", "start_time": "2026-05-02T11:57:09.962786Z" } }, "source": [ "op_custom_sym = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"custom-emission\")\n", " .for_each_site(\"i\")\n", " .emit_when_at_least(\"i\", cutoff=2)\n", " .emit_else(nkdsl.identity(), matrix_element=0.0, tag=\"fallback\")\n", " .build()\n", ")\n", "op_custom = op_custom_sym.compile()\n", "\n", "_xp, mels = op_custom.get_conn_padded(jnp.asarray([0, 1, 2, 3], dtype=jnp.int32))\n", "print(\"mels:\", mels)\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mels: [5. 0. 0. 0. 0. 0. 0. 0.]\n" ] } ], "execution_count": 3 }, { "cell_type": "code", "id": "b0a9601db755a4a7", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:11.808680Z", "start_time": "2026-05-02T11:57:11.804453Z" } }, "source": [ "print(op_custom_sym.to_ir())\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "symbolic.operator @\"custom-emission\" [dtype=float64, hermitian=false, hilbert_size=4] {\n", " ; 1 term(s)\n", "\n", " term #0 \"0\" [kbody, n_iter=4, max_conn_size=8] {\n", " iterate: for (i,) in [(0,), (1,), (2,), ... +1 more]\n", " where: true\n", " emit #0 [tag='emit-when-at-least']:\n", " where: (x[i] >= 2)\n", " update: identity\n", " amplitude: x[i]\n", " emit #1 [tag='fallback']:\n", " where: !(x[i] >= 2)\n", " update: identity\n", " amplitude: 0\n", " }\n", "\n", "}\n" ] } ], "execution_count": 4 }, { "cell_type": "markdown", "id": "5c0bbf80ec23348d", "metadata": {}, "source": [ "Notice how one custom fluent method captures both a predicate policy and default branch metadata.\n", "This is the main value proposition: you call one readable method, while internal behavior stays precise.\n" ] }, { "cell_type": "markdown", "id": "d32c27589d74d9ca", "metadata": {}, "source": [ "## Example 2: registration via generic `@register`\n", "\n", "The generic decorator can register iterator, predicate, or emission clauses.\n", "Here we use it for a second emission clause to show a compact authoring path.\n" ] }, { "cell_type": "code", "id": "c7fd9ee4da5588f6", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:12.634877Z", "start_time": "2026-05-02T11:57:11.809068Z" } }, "source": [ "@nkdsl.register\n", "class EmitIfNonZero(nkdsl.AbstractEmissionClause):\n", " clause_name = \"emit_if_nonzero\"\n", "\n", " def build_emission(self, ctx, label: str = \"i\", *, mel: float = 1.0):\n", " return nkdsl.EmissionClauseSpec(\n", " mode=\"emit_if\",\n", " predicate=ctx.site(label).value != 0,\n", " update=nkdsl.identity(),\n", " matrix_element=float(mel),\n", " tag=\"nonzero\",\n", " )\n", "\n", "\n", "print(\"Registered:\", \"emit_if_nonzero\" in nkdsl.available_emission_clause_names())\n", "\n", "op_nonzero_sym = (\n", " nkdsl.SymbolicDiscreteJaxOperator(hi, \"nonzero\")\n", " .for_each_site(\"i\")\n", " .emit_if_nonzero(\"i\", mel=5.0)\n", " .emit_else(nkdsl.identity(), matrix_element=0.0)\n", " .build()\n", ")\n", "op_nonzero = op_nonzero_sym.compile()\n", "\n", "_xp2, m2 = op_nonzero.get_conn_padded(jnp.asarray([0, 1, 0, 2], dtype=jnp.int32))\n", "print(\"m2:\", m2)\n" ], "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Registered: True\n", "m2: [20. 0. 0. 0. 0. 0. 0. 0.]\n" ] } ], "execution_count": 5 }, { "cell_type": "markdown", "id": "830432adda13aba0", "metadata": {}, "source": [ "## Design rules for production emission clauses\n", "\n", "- Keep one clause focused on one branch policy.\n", "- Validate arguments inside `build_emission`.\n", "- Choose `mode` deliberately and document it.\n", "- Include tags when they improve diagnostics.\n", "- Print IR during development to verify emitted branch structure.\n", "\n", "A good emission clause should make common logic shorter, clearer, and harder to misuse.\n" ] }, { "cell_type": "markdown", "id": "90b84553b4ca21fd", "metadata": {}, "source": [ "## Where this fits in the extension architecture\n", "\n", "With custom iterators, predicates, and emissions together, you can shape:\n", "\n", "- **where to iterate** (iterator clause)\n", "- **when to activate** (predicate clause)\n", "- **what to emit** (emission clause)\n", "\n", "That separation keeps DSL extensions composable and lets teams standardize domain-specific patterns\n", "without forking the compiler.\n" ] }, { "cell_type": "code", "id": "6f0d38b4b4f5485e", "metadata": { "ExecuteTime": { "end_time": "2026-05-02T11:57:12.643843Z", "start_time": "2026-05-02T11:57:12.640243Z" } }, "source": [], "outputs": [], "execution_count": 5 } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "pygments_lexer": "ipython3" }, "mystnb": { "execution_mode": "cache" } }, "nbformat": 4, "nbformat_minor": 5 }