Skip to content
7 changes: 7 additions & 0 deletions src/bloqade/shuttle/analysis/aod/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .analysis import AODAnalysis as AODAnalysis
from .lattice import (
AOD as AOD,
AODState as AODState,
NotAOD as NotAOD,
Unknown as Unknown,
)
46 changes: 46 additions & 0 deletions src/bloqade/shuttle/analysis/aod/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Type, TypeVar

from kirin import interp, ir
from kirin.analysis import const
from kirin.analysis.forward import Forward, ForwardFrame

from bloqade.shuttle.arch import ArchSpecMixin
from bloqade.shuttle.dialects import tracking

from .lattice import AODState


@dataclass
class AODAnalysis(Forward[AODState], ArchSpecMixin):

keys = ["aod.analysis", "spec.interp"]
lattice = AODState

T = TypeVar("T")

def get_const_value(self, typ: Type[T], ssa: ir.SSAValue) -> T:
if not isinstance(value := ssa.hints.get("const"), const.Value):
raise interp.InterpreterError(
"Non-constant value encountered in AOD analysis."
)

if not isinstance(data := value.data, typ):
raise interp.InterpreterError(
f"Expected constant of type {typ}, got {type(data)}."
)

return data

def eval_stmt_fallback(
self, frame: ForwardFrame[AODState], stmt: ir.Statement
) -> tuple[AODState, ...] | interp.SpecialValue[AODState]:
return tuple(
AODState.top()
for result in stmt.results
if result.type.is_subseteq(tracking.SystemStateType)
)

def run_method(self, method: ir.Method, args: tuple[AODState, ...]):
# NOTE: we do not support dynamic calls here, thus no need to propagate method object
return self.run_callable(method.code, (self.lattice.bottom(),) + args)
62 changes: 62 additions & 0 deletions src/bloqade/shuttle/analysis/aod/lattice.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from dataclasses import dataclass

from bloqade.geometry.dialects import grid
from kirin import ir
from kirin.dialects import ilist
from kirin.ir.attrs.abc import LatticeAttributeMeta
from kirin.lattice.abc import BoundedLattice
from kirin.lattice.mixin import SimpleJoinMixin, SimpleMeetMixin
from kirin.print.printer import Printer


@dataclass
class AODState(
ir.Attribute,
SimpleJoinMixin["AODState"],
SimpleMeetMixin["AODState"],
BoundedLattice["AODState"],
metaclass=LatticeAttributeMeta,
):

@classmethod
def bottom(cls) -> "AODState":
return NotAOD()

@classmethod
def top(cls) -> "AODState":
return Unknown()

def print_impl(self, printer: Printer) -> None:
printer.print(self.__class__.__name__ + "()")


@dataclass
class NotAOD(AODState):
def is_subseteq(self, other: AODState) -> bool:
return True


@dataclass
class Unknown(AODState):
def is_subseteq(self, other: AODState) -> bool:
return isinstance(other, Unknown)


@dataclass
class AOD(AODState):
x_tones: frozenset[int]
y_tones: frozenset[int]
pos: grid.Grid

def active_positions(self) -> grid.Grid:
x_indices = ilist.IList(sorted(self.x_tones))
y_indices = ilist.IList(sorted(self.y_tones))
return self.pos.get_view(x_indices, y_indices)

def is_subseteq(self, other: AODState) -> bool:
# only check of the active AOD positions are equal not necessarily
# the exact positions
return (
isinstance(other, AOD)
and self.active_positions() == other.active_positions()
)
23 changes: 14 additions & 9 deletions src/bloqade/shuttle/analysis/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@ class RuntimeFrame(ForwardFrame[EmptyLattice]):
This frame is used to track the state of quantum operations within a method.
"""

quantum_stmts: set[ir.Statement] = field(default_factory=set)
quantum_call: set[ir.Statement] = field(default_factory=set)
"""Set of quantum statements in the frame."""
is_quantum: bool = False
"""Whether the frame contains quantum operations."""

def merge_runtime(self, other: "RuntimeFrame", stmt: ir.Statement):
if other.is_quantum:
self.is_quantum = True
self.quantum_call.add(stmt)
self.quantum_call.update(other.quantum_call)


class RuntimeAnalysis(ForwardExtra[RuntimeFrame, EmptyLattice]):
"""Forward dataflow analysis to check if a method has quantum runtime.
Expand Down Expand Up @@ -61,10 +67,8 @@ def ifelse(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: scf.IfElse
else_frame, stmt.else_body, (_interp.lattice.top(),)
)

frame.is_quantum = (
frame.is_quantum or then_frame.is_quantum or else_frame.is_quantum
)
frame.quantum_stmts.update(then_frame.quantum_stmts, else_frame.quantum_stmts)
frame.merge_runtime(then_frame, stmt)
frame.merge_runtime(else_frame, stmt)
match (then_result, else_result):
case (interp.ReturnValue(), tuple()):
return else_result
Expand All @@ -86,8 +90,7 @@ def for_loop(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: scf.For)
body_frame, stmt.body, (_interp.lattice.bottom(),)
)

frame.is_quantum = frame.is_quantum or body_frame.is_quantum
frame.quantum_stmts.update(body_frame.quantum_stmts)
frame.merge_runtime(body_frame, stmt)
if isinstance(result, interp.ReturnValue) or result is None:
return args[1:]
else:
Expand All @@ -107,7 +110,8 @@ class Func(interp.MethodTable):
def invoke(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: func.Invoke):
args = (_interp.lattice.top(),) * len(stmt.inputs)
callee_frame, result = _interp.run_method(stmt.callee, args)
frame.is_quantum = frame.is_quantum or callee_frame.is_quantum
frame.merge_runtime(callee_frame, stmt)

return (result,)

@interp.impl(func.Call)
Expand All @@ -123,10 +127,11 @@ def call(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: func.Call):
body = trait.get_callable_region(callee_result.code)
with _interp.new_frame(stmt) as callee_frame:
result = _interp.run_ssacfg_region(callee_frame, body, args)

else:
raise InterruptedError("Dynamic method calls are not supported")

frame.is_quantum = frame.is_quantum or callee_frame.is_quantum
frame.merge_runtime(callee_frame, stmt)
return (result,)

@interp.impl(func.Return)
Expand Down
4 changes: 3 additions & 1 deletion src/bloqade/shuttle/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class ArchSpec:
layout: Layout = field(default_factory=_default_layout) # type: ignore
float_constants: dict[str, float] = field(default_factory=dict)
int_constants: dict[str, int] = field(default_factory=dict)
max_x_tones: int = field(default=16, kw_only=True)
max_y_tones: int = field(default=16, kw_only=True)

def __hash__(self):
return hash(
Expand All @@ -126,7 +128,7 @@ def __hash__(self):
class ArchSpecMixin:
"""Base class for interpreters that require an architecture specification."""

arch_spec: ArchSpec
arch_spec: ArchSpec = field(kw_only=True)


@dataclass
Expand Down
72 changes: 0 additions & 72 deletions src/bloqade/shuttle/codegen/taskgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from bloqade.geometry.dialects import grid
from kirin import ir
from kirin.dialects import func, ilist
from kirin.interp import Frame, InterpreterError, MethodTable, impl
from kirin.ir.method import Method
from typing_extensions import Self

Expand Down Expand Up @@ -152,74 +151,3 @@ def run_trace(
# TODO: use permute_values to get correct order.
super().run(mt, args=args, kwargs=kwargs)
return self.trace.copy()


@action.dialect.register(key="action.tracer")
class ActionTracer(MethodTable):

intensity_actions = {
action.TurnOnXY: TurnOnXYAction,
action.TurnOffXY: TurnOffXYAction,
action.TurnOnXSlice: TurnOnXSliceAction,
action.TurnOffXSlice: TurnOffXSliceAction,
action.TurnOnYSlice: TurnOnYSliceAction,
action.TurnOffYSlice: TurnOffYSliceAction,
action.TurnOnXYSlice: TurnOnXYSliceAction,
action.TurnOffXYSlice: TurnOffXYSliceAction,
}

@impl(action.TurnOnXY)
@impl(action.TurnOffXY)
@impl(action.TurnOnXSlice)
@impl(action.TurnOffXSlice)
@impl(action.TurnOnYSlice)
@impl(action.TurnOffYSlice)
@impl(action.TurnOnXYSlice)
@impl(action.TurnOffXYSlice)
def construct_intensity_actions(
self,
interp: TraceInterpreter,
frame: Frame,
stmt: action.IntensityStatement,
):
if interp.curr_pos is None:
raise InterpreterError(
"Position of AOD not set before turning on/off tones"
)

x_tone_indices = frame.get(stmt.x_tones)
y_tone_indices = frame.get(stmt.y_tones)

interp.trace.append(
self.intensity_actions[type(stmt)](
x_tone_indices if isinstance(x_tone_indices, slice) else x_tone_indices,
y_tone_indices if isinstance(y_tone_indices, slice) else y_tone_indices,
)
)
interp.trace.append(WayPointsAction(way_points=[interp.curr_pos]))
return ()

@impl(action.Move)
def move(self, interp: TraceInterpreter, frame: Frame, stmt: action.Move):
if interp.curr_pos is None:
raise InterpreterError("Position of AOD not set before moving tones")

assert isinstance(interp.trace[-1], WayPointsAction)

interp.trace[-1].add_waypoint(pos := frame.get_typed(stmt.grid, grid.Grid))
if interp.curr_pos.shape != pos.shape:
raise InterpreterError(
f"Position of AOD {interp.curr_pos} and target position {pos} have different shapes"
)
interp.curr_pos = pos

return ()

@impl(action.Set)
def set(self, interp: TraceInterpreter, frame: Frame, stmt: action.Set):
pos = frame.get_typed(stmt.grid, grid.Grid)
interp.trace.append(WayPointsAction([pos]))

interp.curr_pos = pos

return ()
1 change: 1 addition & 0 deletions src/bloqade/shuttle/dialects/action/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@
TurnOnYSlice as TurnOnYSlice,
TweezerFunction as TweezerFunction,
)
from .trace import ActionTracer as ActionTracer
78 changes: 78 additions & 0 deletions src/bloqade/shuttle/dialects/action/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from bloqade.geometry.dialects import grid
from kirin.interp import Frame, InterpreterError, MethodTable, impl

from bloqade.shuttle.codegen import taskgen

from . import stmts
from ._dialect import dialect


@dialect.register(key="action.tracer")
class ActionTracer(MethodTable):

intensity_actions = {
stmts.TurnOnXY: taskgen.TurnOnXYAction,
stmts.TurnOffXY: taskgen.TurnOffXYAction,
stmts.TurnOnXSlice: taskgen.TurnOnXSliceAction,
stmts.TurnOffXSlice: taskgen.TurnOffXSliceAction,
stmts.TurnOnYSlice: taskgen.TurnOnYSliceAction,
stmts.TurnOffYSlice: taskgen.TurnOffYSliceAction,
stmts.TurnOnXYSlice: taskgen.TurnOnXYSliceAction,
stmts.TurnOffXYSlice: taskgen.TurnOffXYSliceAction,
}

@impl(stmts.TurnOnXY)
@impl(stmts.TurnOffXY)
@impl(stmts.TurnOnXSlice)
@impl(stmts.TurnOffXSlice)
@impl(stmts.TurnOnYSlice)
@impl(stmts.TurnOffYSlice)
@impl(stmts.TurnOnXYSlice)
@impl(stmts.TurnOffXYSlice)
def construct_intensity_actions(
self,
interp: taskgen.TraceInterpreter,
frame: Frame,
stmt: stmts.IntensityStatement,
):
if interp.curr_pos is None:
raise InterpreterError(
"Position of AOD not set before turning on/off tones"
)

x_tone_indices = frame.get(stmt.x_tones)
y_tone_indices = frame.get(stmt.y_tones)

interp.trace.append(
self.intensity_actions[type(stmt)](
x_tone_indices if isinstance(x_tone_indices, slice) else x_tone_indices,
y_tone_indices if isinstance(y_tone_indices, slice) else y_tone_indices,
)
)
interp.trace.append(taskgen.WayPointsAction(way_points=[interp.curr_pos]))
return ()

@impl(stmts.Move)
def move(self, interp: taskgen.TraceInterpreter, frame: Frame, stmt: stmts.Move):
if interp.curr_pos is None:
raise InterpreterError("Position of AOD not set before moving tones")

assert isinstance(interp.trace[-1], taskgen.WayPointsAction)

interp.trace[-1].add_waypoint(pos := frame.get_typed(stmt.grid, grid.Grid))
if interp.curr_pos.shape != pos.shape:
raise InterpreterError(
f"Position of AOD {interp.curr_pos} and target position {pos} have different shapes"
)
interp.curr_pos = pos

return ()

@impl(stmts.Set)
def set(self, interp: taskgen.TraceInterpreter, frame: Frame, stmt: stmts.Set):
pos = frame.get_typed(stmt.grid, grid.Grid)
interp.trace.append(taskgen.WayPointsAction([pos]))

interp.curr_pos = pos

return ()
2 changes: 1 addition & 1 deletion src/bloqade/shuttle/dialects/gate/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,5 @@ def gate(
) -> interp.StatementResult[RuntimeFrame]:
"""Handle gate statements and mark the frame as quantum."""
frame.is_quantum = True
frame.quantum_stmts.add(stmt)
frame.quantum_call.add(stmt)
return ()
2 changes: 1 addition & 1 deletion src/bloqade/shuttle/dialects/init/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ class HasQuantumRuntimeMethodTable(interp.MethodTable):
def gate(self, interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: Fill):
"""Handle gate statements and mark the frame as quantum."""
frame.is_quantum = True
frame.quantum_stmts.add(stmt)
frame.quantum_call.add(stmt)
return ()
2 changes: 0 additions & 2 deletions src/bloqade/shuttle/dialects/measure/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,4 @@
from .types import (
MeasurementArray as MeasurementArray,
MeasurementArrayType as MeasurementArrayType,
MeasurementResult as MeasurementResult,
MeasurementResultType as MeasurementResultType,
)
2 changes: 1 addition & 1 deletion src/bloqade/shuttle/dialects/measure/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ class HasQuantumRuntimeMethodTable(interp.MethodTable):
def gate(self, _interp: RuntimeAnalysis, frame: RuntimeFrame, stmt: Measure):
"""Handle gate statements and mark the frame as quantum."""
frame.is_quantum = True
frame.quantum_stmts.add(stmt)
frame.quantum_call.add(stmt)
return (_interp.lattice.top(),)
Loading
Loading