Conditional Emissions (if / elseif / else)

This notebook explains piecewise emission logic in the DSL. Conditional emissions let one term choose among multiple branch updates while preserving static output shape, which is critical for JAX-friendly lowering.

We will cover semantics, execution behavior, IR inspection, and common mistakes.

Why conditional emissions exist

Many physical rules are naturally piecewise:

  • if local occupancy is zero, do one transition

  • else if occupancy is one, do another

  • else apply a fallback

Historically, you could encode this with multiple terms, each carrying its own predicate. Conditional emissions keep that logic in one cohesive term, improve readability, and retain deterministic branch accounting (a fixed number of output slots per term).

import jax.numpy as jnp
import netket as nk
import nkdsl

hi = nk.hilbert.Fock(n_max=4, N=1)
print("Hilbert size:", hi.size)
Hilbert size: 1

1) Build a canonical if / elseif / else chain

Read this chain like ordinary control flow, but remember that the compiler still allocates a static slot for every branch (note max_conn_size=3 in the IR below). Runtime predicates only decide which branch contributes a nonzero matrix element.
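To make the static-slot idea concrete, here is a minimal pure-Python model. It does not use nkdsl; the helper `evaluate_static_slots` and the zero-padding of inactive slots are assumptions for illustration. Every branch always produces a candidate, and its predicate only decides whether the amplitude is kept or zeroed, so the output length never depends on the input. nkdsl's actual output layout may differ: the results printed in section 2 place the active entry in the first slot.

```python
from typing import Callable, List, Tuple

# A branch is (predicate on the local occupancy, new occupancy, amplitude).
Branch = Tuple[Callable[[int], bool], int, float]


def evaluate_static_slots(n: int, branches: List[Branch]) -> Tuple[List[int], List[float]]:
    """Model of fixed-shape emission: one output slot per branch, always.

    Hypothetical helper for illustration only; this is not the nkdsl API.
    The predicates are assumed to be mutually exclusive already.
    """
    xp, mels = [], []
    for predicate, new_value, amplitude in branches:
        active = predicate(n)
        # Every slot is filled; inactive slots carry a zero amplitude.
        # Padding inactive states with 0 is an assumption of this model.
        xp.append(new_value if active else 0)
        mels.append(amplitude if active else 0.0)
    return xp, mels


# Guards written in their mutually exclusive (lowered) form, mirroring
# the three branches of the "piecewise" chain built next.
piecewise = [
    (lambda n: n == 0, 1, 10.0),
    (lambda n: n != 0 and n == 1, 2, 20.0),
    (lambda n: n != 0 and n != 1, 3, 30.0),
]

for n in [0, 1, 2, 3]:
    xp, mels = evaluate_static_slots(n, piecewise)
    print(n, xp, mels)
```

The output length is 3 for every n, which is the property that keeps shapes static under JAX tracing.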

op_sym = (
    nkdsl.SymbolicDiscreteJaxOperator(hi, "piecewise")
    .for_each_site("i")
    .emit_if(nkdsl.site("i") == 0, nkdsl.write("i", 1), matrix_element=10.0, tag="if")
    .emit_elseif(nkdsl.site("i") == 1, nkdsl.write("i", 2), matrix_element=20.0, tag="elseif")
    .emit_else(nkdsl.write("i", 3), matrix_element=30.0, tag="else")
    .build()
)
op = op_sym.compile()
print(op_sym.to_ir())
symbolic.operator @"piecewise" [dtype=float64, hermitian=false, hilbert_size=1] {
  ; 1 term(s)

  term #0 "0" [kbody, n_iter=1, max_conn_size=3] {
    iterate: for (i,) in [(0,)]
    where:   true
    emit #0 [tag='if']:
      where:     (x[i] == 0)
      update:    x'[i] = 1
      amplitude: 10
    emit #1 [tag='elseif']:
      where:     (!(x[i] == 0) && (x[i] == 1))
      update:    x'[i] = 2
      amplitude: 20
    emit #2 [tag='else']:
      where:     (!(x[i] == 0) && !(x[i] == 1))
      update:    x'[i] = 3
      amplitude: 30
  }

}

In the IR, inspect each emit block and check that branch predicates represent the intended chain:

  • first branch is exactly the if predicate

  • each elseif is guarded by “previous branches were false” AND its own predicate

  • else is guarded by the negation of all prior branch predicates

That transformation is what enforces mutually exclusive behavior.
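The guard construction described in the list can be sketched in plain Python. `lower_chain` is a hypothetical helper, not part of nkdsl; it only mirrors the structure visible in the printed IR, where each branch is guarded by the negation of every earlier raw predicate AND its own predicate, and the else branch by the negation of all of them.

```python
from typing import Callable, List

Predicate = Callable[[int], bool]


def lower_chain(if_pred: Predicate, elseif_preds: List[Predicate], has_else: bool) -> List[Predicate]:
    """Turn an if/elseif/else chain into mutually exclusive guards (illustrative)."""
    raw = [if_pred] + list(elseif_preds)
    guards: List[Predicate] = []
    for k, pred in enumerate(raw):
        earlier = raw[:k]
        # Guard k: no earlier raw predicate holds, and predicate k holds.
        # For k == 0, `earlier` is empty, so the guard is exactly the if predicate.
        # Loop variables are bound as default arguments to avoid late binding.
        guards.append(lambda x, pred=pred, earlier=earlier: not any(p(x) for p in earlier) and pred(x))
    if has_else:
        # Else guard: negation of every raw predicate in the chain.
        guards.append(lambda x, earlier=raw: not any(p(x) for p in earlier))
    return guards


# Same shape as the "piecewise" chain: if x == 0, elseif x == 1, else.
guards = lower_chain(lambda x: x == 0, [lambda x: x == 1], has_else=True)
for n in range(5):
    active = [k for k, g in enumerate(guards) if g(n)]
    print(n, active)  # index of the active branch guard(s)
```

With an else branch, exactly one guard holds for every input; without one, inputs matching no predicate activate nothing.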

2) Evaluate behavior across representative inputs

Because this Hilbert space has a single site (N=1), each source configuration is just a scalar occupancy. We probe several occupancies and observe which branch becomes active.

for n in [0, 1, 2, 3]:
    x = jnp.asarray([n], dtype=jnp.int32)
    xp, mels = op.get_conn_padded(x)
    print(f"x={n} -> xp={xp.reshape(-1)}, mels={mels}")
x=0 -> xp=[1 0 0], mels=[10.  0.  0.]
x=1 -> xp=[2 0 0], mels=[20.  0.  0.]
x=2 -> xp=[3 0 0], mels=[30.  0.  0.]
x=3 -> xp=[3 0 0], mels=[30.  0.  0.]

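Interpretation strategy:

  1. identify the nonzero matrix-element slot(s)

  2. map each nonzero entry back to a branch of your chain using its rewritten state and amplitude; in the output above the active entry always lands in slot 0, so slot position alone does not identify the branch

  3. verify the rewritten state and amplitude together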

If two slots are nonzero for a supposedly exclusive chain, the predicate logic is likely incorrect.
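These checks can be scripted with plain Python on the arrays printed above. `active_slot` and the `branch_of` lookup are hypothetical helpers, not nkdsl API. The lookup pairs each branch's rewritten state with its amplitude, read off the chain definition, which identifies the branch even though the active entry appears in slot 0 for every input in the printed output.

```python
from typing import List, Optional


def active_slot(mels: List[float], tol: float = 0.0) -> Optional[int]:
    """Return the index of the single nonzero slot, or None if all are zero.

    Raises ValueError if more than one slot is nonzero, which for an
    exclusive chain points at overlapping branch predicates.
    """
    nonzero = [k for k, m in enumerate(mels) if abs(m) > tol]
    if len(nonzero) > 1:
        raise ValueError(f"expected at most one active slot, got {nonzero}")
    return nonzero[0] if nonzero else None


# Rows copied from the output printed in section 2: (x, xp, mels).
observed = [
    (0, [1, 0, 0], [10.0, 0.0, 0.0]),
    (1, [2, 0, 0], [20.0, 0.0, 0.0]),
    (2, [3, 0, 0], [30.0, 0.0, 0.0]),
    (3, [3, 0, 0], [30.0, 0.0, 0.0]),
]

# (new state, amplitude) -> branch tag, taken from the chain definition.
branch_of = {(1, 10.0): "if", (2, 20.0): "elseif", (3, 30.0): "else"}

for x, xp, mels in observed:
    k = active_slot(mels)
    tag = branch_of[(xp[k], mels[k])]
    print(f"x={x}: slot {k} -> branch {tag!r}")
```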

3) Multiple elseif branches

Longer chains are supported. This is useful when discretized local state values map to different transition rules or coefficients.

op_multi = (
    nkdsl.SymbolicDiscreteJaxOperator(hi, "piecewise-4")
    .for_each_site("i")
    .emit_if(nkdsl.site("i") == 0, nkdsl.write("i", 1), matrix_element=1.0)
    .emit_elseif(nkdsl.site("i") == 1, nkdsl.write("i", 2), matrix_element=2.0)
    .emit_elseif(nkdsl.site("i") == 2, nkdsl.write("i", 3), matrix_element=3.0)
    .emit_else(nkdsl.write("i", 4), matrix_element=4.0)
    .build()
    .compile()
)

for n in [0, 1, 2, 3, 4]:
    _xp2, mels2 = op_multi.get_conn_padded(jnp.asarray([n], dtype=jnp.int32))
    print(f"n={n} -> mels={mels2}")
n=0 -> mels=[1. 0. 0. 0.]
n=1 -> mels=[2. 0. 0. 0.]
n=2 -> mels=[3. 0. 0. 0.]
n=3 -> mels=[4. 0. 0. 0.]
n=4 -> mels=[4. 0. 0. 0.]
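An independent reference model gives an expectation to compare against. This sketch uses ordinary first-match if/elif/else evaluation, which the lowered guards are meant to reproduce; `first_match` and the `chain` layout are illustrative assumptions, not nkdsl API. The `printed` dictionary holds the nonzero amplitudes from the output above for n = 0..4 (the Fock space here has n_max=4).

```python
# Each raw branch: (predicate, or None for the else branch; new occupancy; amplitude).
chain = [
    (lambda n: n == 0, 1, 1.0),
    (lambda n: n == 1, 2, 2.0),
    (lambda n: n == 2, 3, 3.0),
    (None, 4, 4.0),  # else branch
]


def first_match(n, chain):
    """First-match evaluation: return (new state, amplitude), or None if nothing fires."""
    for predicate, new_value, amplitude in chain:
        if predicate is None or predicate(n):
            return new_value, amplitude
    return None  # only reachable when the chain has no else branch


# Nonzero amplitudes printed for op_multi above, indexed by n.
printed = {0: 1.0, 1: 2.0, 2: 3.0, 3: 4.0, 4: 4.0}

for n in range(5):
    _, amp = first_match(n, chain)
    print(n, amp, amp == printed[n])
```

Keeping such a reference next to a long chain makes it easy to re-probe every local occupancy after reordering branches.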

4) Rules and error cases

Conditional chaining has intentional structural rules to prevent ambiguous semantics. emit_elseif(...) and emit_else(...) are only valid when an emit_if(...) chain is currently open.

The next cell demonstrates one failure mode so you can recognize the error quickly.

try:
    (
        nkdsl.SymbolicDiscreteJaxOperator(hi, "bad-chain")
        .for_each_site("i")
        .emit_elseif(nkdsl.site("i") == 0, nkdsl.identity(), matrix_element=1.0)
    )
except Exception as exc:
    print(type(exc).__name__, "->", exc)
ValueError -> .emit_elseif() must follow .emit_if(...) or .emit_elseif(...) without intervening term modifiers.
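The ordering rule behaves like a small state machine: emit_if opens a chain, emit_elseif requires an open chain, and emit_else requires one and closes it. Below is a hypothetical pure-Python model of that rule; `ChainBuilder`, `try_build`, and the error texts are illustrative, not nkdsl API. Two behaviors are modeling assumptions rather than confirmed nkdsl semantics: that emit_else closes the chain, and that a term modifier such as where(...) closes it (inferred from the "without intervening term modifiers" wording of the error above).

```python
class ChainBuilder:
    """Hypothetical model of the chain-ordering rule; not the nkdsl builder."""

    def __init__(self):
        self.emits = []          # recorded (kind, tag) pairs
        self.chain_open = False  # True right after emit_if / emit_elseif

    def emit_if(self, tag):
        self.emits.append(("if", tag))
        self.chain_open = True
        return self

    def emit_elseif(self, tag):
        if not self.chain_open:
            raise ValueError("emit_elseif without an open if-chain")
        self.emits.append(("elseif", tag))
        return self

    def emit_else(self, tag):
        if not self.chain_open:
            raise ValueError("emit_else without an open if-chain")
        self.emits.append(("else", tag))
        self.chain_open = False  # assumption: else closes the chain
        return self

    def where(self, _predicate):
        self.chain_open = False  # assumption: a term modifier closes an open chain
        return self


def try_build(steps):
    """Run a builder sequence and report success or the ValueError message."""
    try:
        steps()
        return "ok"
    except ValueError as exc:
        return f"ValueError: {exc}"


print(try_build(lambda: ChainBuilder().emit_if("a").emit_elseif("b").emit_else("c")))
print(try_build(lambda: ChainBuilder().emit_elseif("b")))
print(try_build(lambda: ChainBuilder().emit_if("a").where(None).emit_else("c")))
```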

5) Combining term-level and branch-level conditions

A term-level where(...) remains a coarse gate, and branch predicates refine behavior inside that gate. This is often the cleanest way to encode hard constraints plus piecewise transitions.

op_guarded = (
    nkdsl.SymbolicDiscreteJaxOperator(hi, "guarded-piecewise")
    .for_each_site("i")
    .where(nkdsl.site("i") >= 1)
    .emit_if(nkdsl.site("i") == 1, nkdsl.write("i", 2), matrix_element=11.0)
    .emit_else(nkdsl.write("i", 3), matrix_element=33.0)
    .build()
    .compile()
)

for n in [0, 1, 2]:
    _xp3, mels3 = op_guarded.get_conn_padded(jnp.asarray([n], dtype=jnp.int32))
    print(f"guarded n={n} -> mels={mels3}")
guarded n=0 -> mels=[0. 0.]
guarded n=1 -> mels=[11.  0.]
guarded n=2 -> mels=[33.  0.]
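One way to reason about this combination is that each branch's effective guard is the term-level gate ANDed with the branch's lowered guard. That reading is an assumption consistent with the printed rows, not a statement about nkdsl internals. The pure-Python sketch below (all names illustrative) reproduces the three printed rows and extends the probe to n=3 and n=4.

```python
# Term-level gate, as in .where(site("i") >= 1).
gate = lambda n: n >= 1

# Lowered branch guards for emit_if(site == 1) / emit_else.
branch_guards = {
    "if": lambda n: n == 1,
    "else": lambda n: not (n == 1),
}
amplitude = {"if": 11.0, "else": 33.0}

# Assumed composition: effective guard = gate AND branch guard.
effective = {tag: (lambda n, g=g: gate(n) and g(n)) for tag, g in branch_guards.items()}

for n in range(5):
    active = [tag for tag, g in effective.items() if g(n)]
    mel = amplitude[active[0]] if active else 0.0  # gated inputs emit nothing
    print(n, active, mel)
```

Inputs rejected by the gate activate no branch at all, even though the else branch alone would otherwise catch them.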

Practical recommendations

  • Keep branch order intentional and document it in term names/tags.

  • Use small representative input probes to validate exclusivity.

  • Inspect IR to confirm branch predicates after lowering.

  • Prefer one clear conditional term over several partially overlapping terms when logic is piecewise.

This keeps both user-facing DSL and compiler-facing IR easier to reason about.