nkdsl.compiler.lowering.jax_lowerer

JAX-backend symbolic operator lowerer.

Converts a SymbolicOperatorIR into a concrete compiled operator instance whose connectivity kernel is built by interpreting the AmplitudeExpr / PredicateExpr / UpdateProgram expression trees as JAX operations at trace time.
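"Interpreting expression trees at trace time" means walking each tree recursively in Python and emitting one operation per node. The following is a minimal plain-Python sketch, not this module's implementation. The Expr node, its op names ("const", "var", "site", "add", "mul", "eq") and the env layout (configuration under "x", iterator labels as keys) are hypothetical stand-ins for AmplitudeExpr and PredicateExpr. Under JAX tracing, the same recursion would run once, and the arithmetic would act on traced arrays instead of Python numbers.

```python
from dataclasses import dataclass
from typing import Any, Mapping, Tuple


@dataclass(frozen=True)
class Expr:
    # Hypothetical node: an op name plus child nodes or literal payloads.
    op: str
    args: Tuple[Any, ...] = ()


def evaluate(expr: Expr, env: Mapping[str, Any]) -> Any:
    """Recursively interpret an expression tree against an environment."""
    if expr.op == "const":
        return expr.args[0]
    if expr.op == "var":  # iterator label bound in env, e.g. "i"
        return env[expr.args[0]]
    if expr.op == "site":  # value of configuration x at a computed index
        return env["x"][evaluate(expr.args[0], env)]
    if expr.op == "add":
        return sum(evaluate(a, env) for a in expr.args)
    if expr.op == "mul":
        result = 1
        for a in expr.args:
            result *= evaluate(a, env)
        return result
    if expr.op == "eq":  # predicate node: returns a bool
        return evaluate(expr.args[0], env) == evaluate(expr.args[1], env)
    raise ValueError(f"unknown op: {expr.op}")


# Amplitude 0.5 * x[i] * x[j] for a hypothetical two-body term.
amp = Expr("mul", (Expr("const", (0.5,)),
                   Expr("site", (Expr("var", ("i",)),)),
                   Expr("site", (Expr("var", ("j",)),))))
env = {"x": [1, -1, 1], "i": 0, "j": 1}
print(evaluate(amp, env))  # -0.5
```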

Architecture

For each IR term, the lowerer generates a term runner: a Python function that, given a single input configuration x of shape (hilbert_size,), returns a tuple (x_primes, mels, valids) with shapes (max_conn_size, hilbert_size), (max_conn_size,) and (max_conn_size,), respectively.
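The runner contract fixes the output shapes independently of x. That independence is what lets JAX trace a runner once and reuse it. The following is a minimal plain-Python sketch of that contract. It assumes a hypothetical single-site spin-flip term on ±1 spins. It also assumes that unused slots are padded with copies of x marked invalid with zero amplitude, which is an illustrative convention rather than one stated by the source.

```python
from typing import Callable, List, Tuple

Config = List[int]
RunnerOut = Tuple[List[Config], List[float], List[bool]]


def make_flip_runner(hilbert_size: int, max_conn_size: int) -> Callable[[Config], RunnerOut]:
    """Return a hypothetical term runner for sum_i sigma^x_i on +/-1 spins."""
    if hilbert_size > max_conn_size:
        raise ValueError("max_conn_size must cover one branch per site")

    def runner(x: Config) -> RunnerOut:
        assert len(x) == hilbert_size
        x_primes, mels, valids = [], [], []
        for i in range(hilbert_size):
            xp = list(x)
            xp[i] = -xp[i]  # flip site i
            x_primes.append(xp)
            mels.append(1.0)
            valids.append(True)
        # Pad to the static size so every call returns identically shaped outputs.
        while len(x_primes) < max_conn_size:
            x_primes.append(list(x))
            mels.append(0.0)
            valids.append(False)
        return x_primes, mels, valids

    return runner


runner = make_flip_runner(hilbert_size=3, max_conn_size=4)
x_primes, mels, valids = runner([1, 1, -1])
```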

K-body terms use a static index_array of shape (M, K) and apply jax.vmap over its rows. Each row binds the iterator labels to its K site indices and evaluates all emissions, producing E branches per row, or M * E branches in total.
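The per-row evaluation can be sketched as a plain-Python loop that plays the role of jax.vmap. With vmap, the results would come back batched as (M, E, ...) and then be flattened to M * E branches. This sketch flattens row-major directly. The two emissions, exchange and diagonal, are hypothetical examples of a two-body term and are not emissions defined by this module.

```python
from typing import Callable, Dict, List, Sequence, Tuple

Branch = Tuple[List[int], float, bool]
Emission = Callable[[List[int], Dict[str, int]], Branch]


def run_kbody(x: List[int], index_array: Sequence[Tuple[int, ...]],
              labels: Sequence[str], emissions: Sequence[Emission]):
    """Plain-Python analogue of jax.vmap over the rows of index_array."""
    x_primes, mels, valids = [], [], []
    for row in index_array:  # vmap would batch this loop
        env = dict(zip(labels, row))  # bind iterator labels, e.g. {"i": 0, "j": 1}
        for emit in emissions:  # E emissions per row
            xp, mel, ok = emit(x, env)
            x_primes.append(xp)
            mels.append(mel)
            valids.append(ok)
    return x_primes, mels, valids  # M * E branches, row-major


def exchange(x: List[int], env: Dict[str, int]) -> Branch:
    # Hypothetical off-diagonal emission: swap sites i and j if they differ.
    i, j = env["i"], env["j"]
    xp = list(x)
    xp[i], xp[j] = xp[j], xp[i]
    return xp, 0.5, x[i] != x[j]


def diagonal(x: List[int], env: Dict[str, int]) -> Branch:
    # Hypothetical diagonal emission: 0.25 * x[i] * x[j].
    return list(x), 0.25 * x[env["i"]] * x[env["j"]], True


index_array = [(0, 1), (1, 2)]  # M = 2 bonds, K = 2
xp, mels, valids = run_kbody([1, -1, -1], index_array, ("i", "j"), [exchange, diagonal])
```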

Branch-multiset note

By default, duplicate connected states are coalesced by summing matrix elements and dropping zero-amplitude entries (including invalidated branches). This behavior is controlled by SymbolicCompilerOptions.deduplicate_connected_components.
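The following sketch shows the coalescing rule described above in plain Python: invalid branches are discarded, matrix elements of identical connected states are summed, and states whose total amplitude is zero are dropped. This dynamic-length version is only illustrative. Under jax.jit the output shape must stay static, so the sketch does not reflect how the module handles array sizes. The function name coalesce and the tol parameter are hypothetical.

```python
from typing import Dict, List, Sequence, Tuple


def coalesce(x_primes: Sequence[Sequence[int]], mels: Sequence[float],
             valids: Sequence[bool], tol: float = 0.0) -> Tuple[List[List[int]], List[float]]:
    """Merge duplicate connected states by summing their matrix elements."""
    acc: Dict[Tuple[int, ...], float] = {}
    for xp, mel, ok in zip(x_primes, mels, valids):
        if not ok:  # invalidated branches never contribute
            continue
        key = tuple(xp)
        acc[key] = acc.get(key, 0.0) + mel
    # Drop states whose summed amplitude vanishes (|mel| <= tol).
    kept = [(list(k), v) for k, v in acc.items() if abs(v) > tol]
    return [k for k, _ in kept], [v for _, v in kept]


xs, ms = coalesce(
    x_primes=[[1, 0], [0, 1], [1, 0], [0, 1], [1, 1]],
    mels=[0.5, 1.0, 0.25, -1.0, 2.0],
    valids=[True, True, True, True, False],
)
print(xs, ms)  # [[1, 0]] [0.75]
```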

Functions

apply_single_update_op(op, x, env, ...[, ...])

Applies one update operation to a single configuration vector.

apply_update_program(x, program, env, ...[, ...])

Applies one update program to a single configuration vector.

build_compiled_operator(hilbert, *, ...[, ...])

Builds one compiled operator instance from prepared term runners.

build_default_symbolic_operator_lowering_registry()

Builds a registry with the default NetKet discrete JAX target.

create_compiled_operator(hilbert, *, name, ...)

Creates a compiled operator instance for the requested operator target.

debug_event(msg, *[, scope, pass_name, tag, ...])

Structured debug event with scope and optional pass filtering.

eval_amplitude(expr, env, *[, ...])

Evaluates one amplitude expression against an environment mapping.

eval_predicate(expr, env)

Evaluates one predicate expression against an environment mapping.

infer_shift_mod_spec_from_hilbert(hilbert)

Infers wrapped shift/mod metadata from one Hilbert space.

make_kbody_runner(term, hilbert_size, ...[, ...])

Builds one executable runner for a single K-body IR term.

parse_symbol_declaration_args(args)

Parses a symbol-expression payload into a name and a declaration map.

Classes

AbstractSymbolicLowerer()

Abstract base for backend-specific symbolic operator lowerers.

AmplitudeExpr(op[, args])

Typed expression node for operator matrix elements.

JAXSymbolicLowerer(*[, ...])

JAX-backend symbolic operator lowerer.

KBodyIteratorSpec(labels, index_sets)

Static K-body iterator over a pre-computed list of site-index tuples.

PredicateExpr(op[, args])

Typed boolean expression node for operator branch filtering.

SymbolicCompilationContext(*, operator, ir, ...)

Holds per-compilation mutable state across pipeline stages.

SymbolicCompiledArtifact(operator_name, ...)

Compilation artifact produced by the symbolic compiler pipeline.

SymbolicIRTerm(name, iterator, predicate, ...)

One primitive declarative symbolic operator term.

SymbolicOperatorLoweringRegistry([targets, ...])

Registry for selectable compiled-operator targets.

UpdateOp(kind[, params])

One primitive site-update operation.

UpdateProgram([ops])

Ordered immutable sequence of site-update operations.
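The SymbolicOperatorLoweringRegistry and create_compiled_operator entries suggest a lookup from a target name to a factory that builds the compiled operator. The following is a minimal sketch under that assumption. TargetRegistry, its register/create methods and the target name "netket_discrete_jax" are hypothetical and are not this module's API.

```python
from typing import Any, Callable, Dict, Optional

Factory = Callable[..., Any]


class TargetRegistry:
    """Hypothetical mapping from target names to compiled-operator factories."""

    def __init__(self, targets: Optional[Dict[str, Factory]] = None,
                 default: Optional[str] = None) -> None:
        self._targets: Dict[str, Factory] = dict(targets or {})
        self.default = default

    def register(self, name: str, factory: Factory) -> None:
        # Refuse silent overrides so target selection stays unambiguous.
        if name in self._targets:
            raise KeyError(f"target already registered: {name}")
        self._targets[name] = factory

    def create(self, hilbert: Any, *, target: Optional[str] = None, **kwargs: Any) -> Any:
        # Fall back to the default target when none is requested.
        key = target or self.default
        if key not in self._targets:
            raise KeyError(f"unknown target: {key!r}; available: {sorted(self._targets)}")
        return self._targets[key](hilbert, **kwargs)


registry = TargetRegistry(default="netket_discrete_jax")
registry.register("netket_discrete_jax", lambda hilbert, **kw: ("compiled", hilbert, kw))
op = registry.create("hilbert-placeholder", name="H")
```

Keeping factories behind names lets callers select a backend target without importing it directly. Under that design, a default registry is simply one pre-populated with a single target.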