Source code for nkdsl.compiler.compiler

# Copyright (c) 2026 The neuraLQX and nkDSL Authors - All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Symbolic operator compiler orchestrator."""

from __future__ import annotations

from collections.abc import Callable
from typing import Any

from nkdsl.compiler.cache.store import (
    AbstractSymbolicArtifactStore,
)
from nkdsl.compiler.core.artifact import (
    SymbolicCompiledArtifact,
)
from nkdsl.compiler.core.context import (
    SymbolicCompilationContext,
)
from nkdsl.compiler.core.options import (
    SymbolicCompilerOptions,
)
from nkdsl.compiler.core.pipeline import (
    SymbolicPassPipeline,
)
from nkdsl.compiler.core.signature import (
    SymbolicCompilationSignature,
)
from nkdsl.compiler.lowering.registry import (
    SymbolicLowererRegistry,
)
from nkdsl.compiler.lowering.operator_registry import (
    SymbolicOperatorLoweringRegistry,
)
from nkdsl.debug import event as debug_event
from nkdsl.errors import SymbolicCompilerError


class SymbolicCompiler:
    """
    Orchestrates the symbolic operator compilation pipeline.

    The compiler accepts a symbolic operator (an
    :class:`~nkdsl.core.base.AbstractSymbolicOperator`), runs it through the
    registered pass pipeline, optionally resolves a cache hit, and, on a miss,
    invokes the appropriate lowerer to produce a concrete executable operator
    instance.

    Typical usage::

        from nkdsl import SymbolicCompiler

        compiler = SymbolicCompiler()
        compiled_op = compiler.compile_operator(my_symbolic_op)
        xp, mels = compiled_op.get_conn_padded(x_batch)

    Args:
        pipeline: Pass pipeline to use. Defaults to
            :func:`~nkdsl.compiler.defaults.default_symbolic_pass_pipeline`.
        lowerer_registry: Lowerer registry. Defaults to
            :func:`~nkdsl.compiler.defaults.default_symbolic_lowerer_registry`,
            built against the resolved ``operator_lowering_registry``.
        operator_lowering_registry: Registry mapping operator-lowering names to
            target classes and connection methods. Defaults to
            :func:`~nkdsl.compiler.defaults.default_symbolic_operator_lowering_registry`.
        artifact_store: Artifact cache store. Defaults to a fresh
            :class:`~nkdsl.compiler.cache.store.InMemorySymbolicArtifactStore`
            created for (and owned by) this compiler instance.
        options: Compiler options. Defaults to :class:`SymbolicCompilerOptions`
            with all defaults.
        backend_preference: Convenience override for
            ``options.backend_preference``.
        cache_enabled: Convenience override for ``options.cache_enabled``.
        deduplicate_connected_components: Convenience override for
            ``options.deduplicate_connected_components``.
        operator_lowering: Convenience override for
            ``options.operator_lowering``.

    Note:
        Collaborators are replaced by their defaults only when the argument is
        ``None``. An explicitly supplied object is always used, even when it is
        falsy (for example an empty artifact store whose ``len()`` is ``0``).
    """

    @staticmethod
    def _default_if_none(value: Any, factory: Callable[[], Any]) -> Any:
        """
        Returns ``value``, or ``factory()`` when ``value`` is ``None``.

        Unlike ``value or factory()``, this keeps explicitly supplied objects
        that happen to be falsy (e.g. stores or registries defining ``__len__``
        that are currently empty). ``factory`` is only called when needed.
        """
        return factory() if value is None else value

    def __init__(
        self,
        *,
        pipeline: SymbolicPassPipeline | None = None,
        lowerer_registry: SymbolicLowererRegistry | None = None,
        operator_lowering_registry: SymbolicOperatorLoweringRegistry | None = None,
        artifact_store: AbstractSymbolicArtifactStore | None = None,
        options: SymbolicCompilerOptions | None = None,
        backend_preference: str | None = None,
        cache_enabled: bool | None = None,
        deduplicate_connected_components: bool | None = None,
        operator_lowering: str | None = None,
    ) -> None:
        # Function-scope imports; presumably to avoid an import cycle with
        # ``nkdsl.compiler.defaults`` -- NOTE(review): confirm.
        from nkdsl.compiler.cache.store import InMemorySymbolicArtifactStore
        from nkdsl.compiler.defaults import (
            default_symbolic_lowerer_registry,
            default_symbolic_operator_lowering_registry,
            default_symbolic_pass_pipeline,
        )

        self._pipeline = self._default_if_none(pipeline, default_symbolic_pass_pipeline)
        # Resolved before the lowerer registry: the default lowerer registry is
        # built against this operator-lowering registry.
        self._operator_lowering_registry = self._default_if_none(
            operator_lowering_registry,
            default_symbolic_operator_lowering_registry,
        )
        self._registry = self._default_if_none(
            lowerer_registry,
            lambda: default_symbolic_lowerer_registry(
                operator_lowering_registry=self._operator_lowering_registry
            ),
        )
        # The store supports ``len()`` (see ``cache_size``), so an empty store is
        # falsy; an ``is None`` check keeps a caller-supplied empty store.
        self._store = self._default_if_none(artifact_store, InMemorySymbolicArtifactStore)

        resolved_options = self._default_if_none(options, SymbolicCompilerOptions)
        if (
            backend_preference is not None
            or cache_enabled is not None
            or deduplicate_connected_components is not None
            or operator_lowering is not None
        ):
            # NOTE(review): options are rebuilt field by field; any field of
            # SymbolicCompilerOptions not listed here would be reset to its
            # default. Keep this list in sync with the options class.
            resolved_options = SymbolicCompilerOptions(
                backend_preference=(
                    backend_preference
                    if backend_preference is not None
                    else resolved_options.backend_preference
                ),
                deduplicate_connected_components=(
                    deduplicate_connected_components
                    if deduplicate_connected_components is not None
                    else resolved_options.deduplicate_connected_components
                ),
                enable_fusion=resolved_options.enable_fusion,
                strict_validation=resolved_options.strict_validation,
                cache_enabled=(
                    cache_enabled if cache_enabled is not None else resolved_options.cache_enabled
                ),
                cache_namespace=resolved_options.cache_namespace,
                operator_lowering=(
                    operator_lowering
                    if operator_lowering is not None
                    else resolved_options.operator_lowering
                ),
                diagnostics_enabled=resolved_options.diagnostics_enabled,
                diagnostics_min_severity=resolved_options.diagnostics_min_severity,
                fail_on_warnings=resolved_options.fail_on_warnings,
                max_diagnostics=resolved_options.max_diagnostics,
                lint_state_sample_size=resolved_options.lint_state_sample_size,
                lint_branch_sample_cap=resolved_options.lint_branch_sample_cap,
                lint_max_exact_hilbert_states=resolved_options.lint_max_exact_hilbert_states,
                debug_flags=resolved_options.debug_flags,
            )
        self._options = resolved_options

    def compile(
        self,
        operator: Any,
        *,
        options: SymbolicCompilerOptions | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> SymbolicCompiledArtifact:
        """
        Compiles a symbolic operator to a :class:`SymbolicCompiledArtifact`.

        Steps:

        1. Extracts the :class:`~nkdsl.ir.program.SymbolicOperatorIR` from the
           operator via ``to_ir()``.
        2. Creates a :class:`~nkdsl.compiler.core.context.SymbolicCompilationContext`.
        3. Runs the pre-cache pass stage.
        4. Computes the cache key (when caching is enabled) and checks the
           artifact store.
        5. On a cache hit: returns the cached artifact (post-cache passes and
           lowering are skipped).
        6. On a cache miss: runs post-cache passes, resolves a lowerer, lowers
           the operator, stores the artifact (re-created with the cache key
           attached), and returns it.

        Args:
            operator: Symbolic operator with a ``to_ir()`` method.
            options: Override compiler options for this single invocation.
                When ``None``, the compiler's own options are used.
            metadata: Extra metadata forwarded to the compilation context.

        Returns:
            Compiled artifact.

        Raises:
            :class:`~nkdsl.errors.SymbolicCompilerError`: When IR extraction,
                either pass stage, lowerer resolution, or lowering fails. The
                underlying exception is chained as ``__cause__``.
        """
        effective_options = self._options if options is None else options
        debug_event(
            "starting symbolic compilation",
            scope="compile",
            tag="COMPILER",
            backend_preference=effective_options.backend_preference,
            operator_lowering=effective_options.operator_lowering,
            cache_enabled=effective_options.cache_enabled,
            strict_validation=effective_options.strict_validation,
        )

        # Extract IR
        try:
            ir = operator.to_ir()
        except Exception as exc:
            raise SymbolicCompilerError(
                f"Failed to extract IR from operator {operator!r}: {exc}"
            ) from exc
        debug_event(
            "extracted symbolic ir from operator",
            scope="ir",
            tag="IR",
            operator_name=ir.operator_name,
            term_count=ir.term_count,
        )

        # Build context
        context = SymbolicCompilationContext(
            operator=operator,
            ir=ir,
            options=effective_options,
            metadata=metadata,
        )
        debug_event(
            "created compilation context",
            scope="compile",
            tag="COMPILER",
            operator_name=ir.operator_name,
            metadata_keys=tuple(sorted((metadata or {}).keys())),
        )

        # Pre-cache passes
        try:
            self._pipeline.run_pre_cache(context)
        except Exception as exc:
            raise SymbolicCompilerError(
                f"Pre-cache pass failed for operator {ir.operator_name!r}: {exc}"
            ) from exc
        debug_event(
            "completed pre-cache pass stage",
            scope="compile",
            tag="COMPILER",
            operator_name=ir.operator_name,
            pass_count=len(context.pass_reports),
        )

        # Cache lookup. The key is derived from the context after the pre-cache
        # stage, so those passes always run, even on a cache hit.
        if effective_options.cache_enabled:
            sig = SymbolicCompilationSignature.from_context(context)
            cache_key = sig.build_cache_key(
                namespace=effective_options.cache_namespace,
            )
            cached = self._store.get(cache_key)
            if cached is not None:
                debug_event(
                    "cache hit",
                    scope="cache",
                    tag="CACHE",
                    operator_name=ir.operator_name,
                    cache_key=cache_key,
                )
                return cached
            debug_event(
                "cache miss",
                scope="cache",
                tag="CACHE",
                operator_name=ir.operator_name,
                cache_key=cache_key,
            )
        else:
            cache_key = None
            debug_event(
                "cache disabled for compile invocation",
                scope="cache",
                tag="CACHE",
                operator_name=ir.operator_name,
            )

        # Post-cache passes
        try:
            self._pipeline.run_post_cache(context)
        except Exception as exc:
            raise SymbolicCompilerError(
                f"Post-cache pass failed for operator {ir.operator_name!r}: {exc}"
            ) from exc
        debug_event(
            "completed post-cache pass stage",
            scope="compile",
            tag="COMPILER",
            operator_name=ir.operator_name,
            pass_count=len(context.pass_reports),
        )

        # Resolve lowerer and lower; failures in either step are reported as
        # lowering failures.
        try:
            lowerer = self._registry.resolve(context)
            debug_event(
                "resolved lowerer",
                scope="lowering",
                tag="LOWERING",
                operator_name=ir.operator_name,
                lowerer_name=lowerer.name,
                backend=lowerer.backend,
            )
            artifact = lowerer.lower(context)
        except Exception as exc:
            raise SymbolicCompilerError(
                f"Lowering failed for operator {ir.operator_name!r}: {exc}"
            ) from exc
        debug_event(
            "lowered symbolic operator",
            scope="lowering",
            tag="LOWERING",
            operator_name=ir.operator_name,
            selected_lowerer=context.selected_lowerer,
            backend=artifact.backend,
        )

        # Attach cache key to artifact when caching is enabled: the artifact is
        # re-created (not mutated) with the key, then stored.
        if effective_options.cache_enabled and cache_key is not None:
            artifact = SymbolicCompiledArtifact.create(
                operator_name=artifact.operator_name,
                backend=artifact.backend,
                lowerer_name=artifact.lowerer_name,
                compiled_operator=artifact.compiled_operator,
                cache_key=cache_key,
                pass_reports=artifact.pass_reports,
                metadata=artifact.metadata_map(),
            )
            self._store.put(cache_key, artifact)
            debug_event(
                "stored compiled artifact in cache",
                scope="cache",
                tag="CACHE",
                operator_name=ir.operator_name,
                cache_key=cache_key,
                cache_size=len(self._store),
            )
        else:
            debug_event(
                "completed compilation without cache store",
                scope="compile",
                tag="COMPILER",
                operator_name=ir.operator_name,
            )
        return artifact

    def compile_operator(
        self,
        operator: Any,
        *,
        options: SymbolicCompilerOptions | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> Any:
        """
        Compiles a symbolic operator and returns the executable operator directly.

        Convenience wrapper around :meth:`compile` that unwraps the artifact.

        Args:
            operator: Symbolic operator with a ``to_ir()`` method.
            options: Override compiler options.
            metadata: Extra metadata for the context.

        Returns:
            Executable compiled operator instance.

        Raises:
            :class:`~nkdsl.errors.SymbolicCompilerError`: Propagated from
                :meth:`compile`.
        """
        debug_event(
            "compile_operator wrapper invoked",
            scope="compile",
            tag="COMPILER",
        )
        artifact = self.compile(operator, options=options, metadata=metadata)
        return artifact.compiled_operator

    def clear_cache(self) -> None:
        """Clears all entries from the artifact store."""
        self._store.clear()

    @property
    def cache_size(self) -> int:
        """Returns the number of cached artifacts."""
        return len(self._store)

    @property
    def pass_names(self) -> tuple[str, ...]:
        """Returns the full ordered pass name sequence."""
        return self._pipeline.pass_names()

    @property
    def lowerer_names(self) -> tuple[str, ...]:
        """Returns registered lowerer names."""
        return self._registry.lowerer_names

    @property
    def operator_lowering_names(self) -> tuple[str, ...]:
        """Returns registered operator-lowering target names."""
        return self._operator_lowering_registry.target_names

    def __repr__(self) -> str:
        return (
            f"SymbolicCompiler("
            f"passes={self.pass_names!r}, "
            f"lowerers={self.lowerer_names!r}, "
            f"operator_lowerings={self.operator_lowering_names!r}, "
            f"cache_size={self.cache_size})"
        )
# Lazily-created process-wide compiler used by ``compile_symbolic_operator``.
_DEFAULT_COMPILER: SymbolicCompiler | None = None


def _shared_default_compiler() -> SymbolicCompiler:
    """Returns the module-level shared compiler, creating it on first use."""
    global _DEFAULT_COMPILER  # noqa: PLW0603
    compiler = _DEFAULT_COMPILER
    if compiler is None:
        compiler = SymbolicCompiler()
        _DEFAULT_COMPILER = compiler
        debug_event(
            "initialized default symbolic compiler",
            scope="compile",
            tag="COMPILER",
        )
    return compiler


def compile_symbolic_operator(
    operator: Any,
    *,
    options: SymbolicCompilerOptions | None = None,
    metadata: dict[str, Any] | None = None,
) -> Any:
    """
    Module-level convenience function for one-shot symbolic compilation.

    Delegates to a single module-level :class:`SymbolicCompiler` that is built
    on the first call and reused afterwards, so its artifact cache persists
    across calls until :func:`reset_default_symbolic_compiler` is invoked.

    Args:
        operator: Symbolic operator with a ``to_ir()`` method.
        options: Override compiler options.
        metadata: Extra metadata for the context.

    Returns:
        Executable compiled operator instance.

    Example::

        from nkdsl import compile_symbolic_operator

        compiled_op = compile_symbolic_operator(my_symbolic_op)
        xp, mels = compiled_op.get_conn_padded(x_batch)
    """
    shared = _shared_default_compiler()
    return shared.compile_operator(operator, options=options, metadata=metadata)


def reset_default_symbolic_compiler() -> None:
    """
    Drops the module-level shared compiler instance.

    The next :func:`compile_symbolic_operator` call builds a fresh compiler.
    Mainly intended for tests that need a deterministic singleton lifecycle.
    """
    global _DEFAULT_COMPILER  # noqa: PLW0603
    _DEFAULT_COMPILER = None


__all__ = [
    "SymbolicCompiler",
    "compile_symbolic_operator",
    "reset_default_symbolic_compiler",
]