"""Loop-structured single-shooting optimal-control abstractions."""
from __future__ import annotations
from dataclasses import dataclass, replace
from .function import Function
from ._staged import _create_rust_project
from ._staged import _generate_rust
from ._staged import _simplify_function
from .sx import SX, SXVector
# Type alias for a symbolic argument used as a Function input or output in
# this module: either a scalar ``SX`` or an ``SXVector``.
FunctionArg = SX | SXVector
[docs]
@dataclass(frozen=True, slots=True)
class SingleShootingBundle:
"""Requested outputs for a joint single-shooting kernel."""
include_cost: bool = False
include_gradient: bool = False
include_hvp: bool = False
include_states: bool = False
[docs]
def add_cost(self) -> SingleShootingBundle:
"""Include the total cost in the joint kernel outputs."""
return replace(self, include_cost=True)
[docs]
def add_gradient(self) -> SingleShootingBundle:
"""Include the gradient with respect to the packed control sequence."""
return replace(self, include_gradient=True)
[docs]
def add_hvp(self) -> SingleShootingBundle:
"""Include the HVP with respect to the packed control sequence."""
return replace(self, include_hvp=True)
[docs]
def add_rollout_states(self) -> SingleShootingBundle:
"""Include the packed rollout state trajectory."""
return replace(self, include_states=True)
[docs]
@dataclass(frozen=True, slots=True)
class SingleShootingPrimalFunction:
"""Primal single-shooting cost kernel."""
problem: SingleShootingProblem
name: str
include_states: bool = False
simplification: int | str | None = None
[docs]
def to_function(self, name: str | None = None) -> Function:
"""Expand this staged kernel into a symbolic ``Function``."""
function = self.problem.to_function(
include_states=self.include_states,
name=name or self.name,
)
return _simplify_function(function, self.simplification)
@property
def nodes(self):
"""Return dependency nodes for shared helper discovery."""
return self.to_function().nodes
@property
def input_names(self) -> tuple[str, ...]:
"""Return the exposed input names."""
return self.problem.input_names
@property
def output_names(self) -> tuple[str, ...]:
"""Return the exposed output names."""
names = ["cost"]
if self.include_states:
names.append("x_traj")
return tuple(names)
[docs]
def generate_rust(
self,
*,
config=None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Generate compact Rust for the staged primal kernel."""
return _generate_rust(
self,
config=config,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
[docs]
def create_rust_project(
self,
path: str,
*,
config=None,
crate_name: str | None = None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Create a Rust crate containing the staged primal kernel."""
return _create_rust_project(
self,
path,
config=config,
crate_name=crate_name,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
[docs]
@dataclass(frozen=True, slots=True)
class SingleShootingGradientFunction:
"""Gradient kernel for a single-shooting optimal-control problem."""
problem: SingleShootingProblem
name: str
include_states: bool = False
simplification: int | str | None = None
[docs]
def to_function(self, name: str | None = None) -> Function:
"""Expand this staged gradient into a regular symbolic ``Function``."""
gradient_function = self.problem._expanded_gradient_function(
include_states=self.include_states,
name=name or self.name,
)
return _simplify_function(gradient_function, self.simplification)
@property
def nodes(self):
"""Return dependency nodes for shared helper discovery."""
return self.to_function().nodes
@property
def input_names(self) -> tuple[str, ...]:
"""Return the exposed input names."""
return self.problem.input_names
@property
def output_names(self) -> tuple[str, ...]:
"""Return the exposed output names."""
names = [f"gradient_{self.problem.control_sequence_name}"]
if self.include_states:
names.append("x_traj")
return tuple(names)
[docs]
def generate_rust(
self,
*,
config=None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Generate compact Rust for the staged gradient kernel."""
return _generate_rust(
self,
config=config,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
[docs]
def create_rust_project(
self,
path: str,
*,
config=None,
crate_name: str | None = None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Create a Rust crate containing the staged gradient kernel."""
return _create_rust_project(
self,
path,
config=config,
crate_name=crate_name,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
[docs]
@dataclass(frozen=True, slots=True)
class SingleShootingHvpFunction:
"""HVP kernel for a single-shooting optimal-control problem."""
problem: SingleShootingProblem
name: str
include_states: bool = False
simplification: int | str | None = None
[docs]
def to_function(self, name: str | None = None) -> Function:
"""Expand this staged HVP kernel into a symbolic ``Function``."""
hvp_function = self.problem._expanded_hvp_function(
include_states=self.include_states,
name=name or self.name,
)
return _simplify_function(hvp_function, self.simplification)
@property
def nodes(self):
"""Return dependency nodes for shared helper discovery."""
return self.to_function().nodes
@property
def input_names(self) -> tuple[str, ...]:
"""Return the exposed input names."""
return (
*self.problem.input_names,
f"v_{self.problem.control_sequence_name}",
)
@property
def output_names(self) -> tuple[str, ...]:
"""Return the exposed output names."""
names = [f"hvp_{self.problem.control_sequence_name}"]
if self.include_states:
names.append("x_traj")
return tuple(names)
[docs]
def generate_rust(
self,
*,
config=None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Generate compact Rust for the staged HVP kernel."""
return _generate_rust(
self,
config=config,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
[docs]
def create_rust_project(
self,
path: str,
*,
config=None,
crate_name: str | None = None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Create a Rust crate containing the staged HVP kernel."""
return _create_rust_project(
self,
path,
config=config,
crate_name=crate_name,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
[docs]
@dataclass(frozen=True, slots=True)
class SingleShootingJointFunction:
"""Joint cost/gradient/HVP/state kernel for a single-shooting problem."""
problem: SingleShootingProblem
bundle: SingleShootingBundle
name: str
simplification: int | str | None = None
[docs]
def to_function(self, name: str | None = None) -> Function:
"""Expand this staged joint kernel into a symbolic ``Function``."""
function = self.problem._expanded_joint_function(
self.bundle,
name=name or self.name,
)
return _simplify_function(function, self.simplification)
@property
def nodes(self):
"""Return dependency nodes for shared helper discovery."""
return self.to_function().nodes
@property
def input_names(self) -> tuple[str, ...]:
"""Return the exposed input names."""
return self.problem.input_names
@property
def output_names(self) -> tuple[str, ...]:
"""Return the exposed output names."""
return _single_shooting_bundle_output_names(self.problem, self.bundle)
[docs]
def generate_rust(
self,
*,
config=None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Generate compact Rust for the staged joint kernel."""
return _generate_rust(
self,
config=config,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
[docs]
def create_rust_project(
self,
path: str,
*,
config=None,
crate_name: str | None = None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Create a Rust crate containing the staged joint kernel."""
return _create_rust_project(
self,
path,
config=config,
crate_name=crate_name,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
[docs]
@dataclass(frozen=True, slots=True)
class SingleShootingProblem:
"""Deterministic single-shooting optimal-control problem.
The problem represents a fixed-horizon rollout with dynamics
``x_next = f(x, u, p)`` and total cost
``sum_k ell(x_k, u_k, p) + V_f(x_N, p)``. Optional vector-valued
penalty residuals may be supplied to augment this cost as
``c / 2 * ||q(x_k, u_k, p)||_2^2`` at each stage and
``c / 2 * ||q_N(x_N, p)||_2^2`` at the terminal state.
Args:
name: Name used for expanded functions and generated kernels.
horizon: Optional positive rollout horizon. It may be supplied in the
constructor or later with :meth:`with_horizon`.
dynamics: Optional function accepting ``(x, u, p)`` and returning the
next state with the same shape as ``x``. It may be supplied in the
constructor or later with :meth:`with_dynamics`.
stage_cost: Optional scalar function accepting ``(x, u, p)`` and
returning ``ell(x, u, p)``. It may be supplied in the constructor
or later with :meth:`with_stage_cost`.
terminal_cost: Optional scalar function accepting ``(x, p)`` and
returning ``V_f(x, p)``. It may be supplied in the constructor or
later with :meth:`with_terminal_cost`.
initial_state_name: Runtime input name for the initial state.
control_sequence_name: Runtime input name for the packed controls.
parameter_name: Runtime input name for the shared parameters.
simplification: Optional simplification effort applied to expanded
functions and derivative helper kernels.
stage_penalty: Optional vector-valued residual function accepting
``(x, u, p)`` and returning ``q(x, u, p)``.
terminal_penalty: Optional vector-valued residual function accepting
``(x, p)`` and returning ``q_N(x, p)``.
penalty_weight: Optional scalar ``c`` multiplying both squared
residual norms. Pass a numeric value to bake ``c`` into generated
code, or pass an ``SX`` symbol such as ``SX.sym("c")`` to expose
``c`` as a runtime scalar input.
"""
name: str
horizon: int | None = None
dynamics: Function | None = None
stage_cost: Function | None = None
terminal_cost: Function | None = None
initial_state_name: str = "x0"
control_sequence_name: str = "U"
parameter_name: str = "p"
simplification: int | str | None = None
stage_penalty: Function | None = None
terminal_penalty: Function | None = None
penalty_weight: float | SX | None = None
def __post_init__(self) -> None:
"""Validate stage and cost signatures."""
if self.horizon is not None and self.horizon <= 0:
raise ValueError("horizon must be a positive integer")
if self.dynamics is not None:
_validate_single_shooting_dynamics(self.dynamics)
if self.stage_cost is not None:
_validate_single_shooting_stage_cost(self.stage_cost)
if self.terminal_cost is not None:
_validate_single_shooting_terminal_cost(self.terminal_cost)
if self._has_complete_core():
_validate_single_shooting_penalty_configuration(
self.stage_penalty,
self.terminal_penalty,
self.penalty_weight,
)
if self.stage_penalty is not None:
_validate_single_shooting_stage_penalty(self.stage_penalty)
if self.terminal_penalty is not None:
_validate_single_shooting_terminal_penalty(
self.terminal_penalty
)
if self._has_complete_core():
_validate_single_shooting_shapes(
self._dynamics(),
self._stage_cost(),
self._terminal_cost(),
self.stage_penalty,
self.terminal_penalty,
)
[docs]
def with_horizon(self, horizon: int) -> SingleShootingProblem:
"""Return a copy configured with a positive rollout horizon.
Args:
horizon: Number of control intervals in the single-shooting
rollout. It must be a positive integer.
Returns:
A new :class:`SingleShootingProblem` with the requested horizon.
"""
return replace(self, horizon=horizon)
[docs]
def with_dynamics(self, dynamics: Function) -> SingleShootingProblem:
"""Return a copy configured with the dynamics function.
Args:
dynamics: Function accepting ``(x, u, p)`` and returning the next
state with the same shape as ``x``.
Returns:
A new :class:`SingleShootingProblem` using ``dynamics``.
"""
return replace(self, dynamics=dynamics)
[docs]
def with_stage_cost(self, stage_cost: Function) -> SingleShootingProblem:
"""Return a copy configured with the scalar stage cost.
Args:
stage_cost: Function accepting ``(x, u, p)`` and returning one
scalar output.
Returns:
A new :class:`SingleShootingProblem` using ``stage_cost``.
"""
return replace(self, stage_cost=stage_cost)
[docs]
def with_terminal_cost(
self, terminal_cost: Function
) -> SingleShootingProblem:
"""Return a copy configured with the scalar terminal cost.
Args:
terminal_cost: Function accepting ``(x, p)`` and returning one
scalar output.
Returns:
A new :class:`SingleShootingProblem` using ``terminal_cost``.
"""
return replace(self, terminal_cost=terminal_cost)
[docs]
def with_costs(
self,
stage_cost: Function,
terminal_cost: Function,
) -> SingleShootingProblem:
"""Return a copy configured with stage and terminal costs.
Args:
stage_cost: Scalar stage cost accepting ``(x, u, p)``.
terminal_cost: Scalar terminal cost accepting ``(x, p)``.
Returns:
A new :class:`SingleShootingProblem` using both costs.
"""
return replace(
self,
stage_cost=stage_cost,
terminal_cost=terminal_cost,
)
[docs]
def with_penalties(
self,
stage_penalty: Function,
terminal_penalty: Function,
penalty_weight: float | SX,
) -> SingleShootingProblem:
"""Return a copy configured with residual penalties.
Args:
stage_penalty: Vector or scalar residual accepting ``(x, u, p)``.
terminal_penalty: Vector or scalar residual accepting ``(x, p)``.
penalty_weight: Numeric penalty weight, or an ``SX`` symbol such
as ``SX.sym("c")`` to expose the weight as a runtime input.
Returns:
A new :class:`SingleShootingProblem` with residual penalties.
"""
if penalty_weight is None:
raise ValueError("penalty_weight must be provided")
return replace(
self,
stage_penalty=stage_penalty,
terminal_penalty=terminal_penalty,
penalty_weight=penalty_weight,
)
[docs]
def with_simplification(
self, simplification: int | str | None
) -> SingleShootingProblem:
"""Return a copy with the requested simplification effort.
Args:
simplification: Simplification effort used for expanded functions
and derivative helper kernels.
Returns:
A new :class:`SingleShootingProblem` with the simplification
setting.
"""
return replace(self, simplification=simplification)
@property
def state_size(self) -> int:
"""Return the state dimension."""
return _single_shooting_arg_size(self._dynamics().inputs[0])
@property
def control_size(self) -> int:
"""Return the per-stage control dimension."""
return _single_shooting_arg_size(self._dynamics().inputs[1])
@property
def parameter_size(self) -> int:
"""Return the shared parameter-vector dimension."""
return _single_shooting_arg_size(self._dynamics().inputs[2])
@property
def has_runtime_penalty_weight(self) -> bool:
"""Return whether ``c`` is exposed as a runtime scalar input."""
return isinstance(self.penalty_weight, SX)
@property
def penalty_weight_name(self) -> str | None:
"""Return the runtime input name for symbolic ``c``."""
if isinstance(self.penalty_weight, SX):
return self.penalty_weight.name or "c"
return None
@property
def input_names(self) -> tuple[str, ...]:
"""Return the exposed runtime input names."""
names = [
self.initial_state_name,
self.control_sequence_name,
self.parameter_name,
]
if self.has_runtime_penalty_weight:
penalty_weight_name = self.penalty_weight_name
if penalty_weight_name is not None:
names.append(penalty_weight_name)
return tuple(names)
@property
def output_names(self) -> tuple[str, ...]:
"""Return the default primal output names."""
return ("cost",)
@property
def inputs(self) -> tuple[FunctionArg, ...]:
"""Return compiled symbolic inputs."""
return self._compiled_inputs()
@property
def outputs(self) -> tuple[FunctionArg, ...]:
"""Return compiled symbolic outputs for the primal cost kernel."""
return self.to_function().outputs
@property
def nodes(self):
"""Return dependency nodes for shared helper discovery."""
return self.to_function().nodes
def _has_complete_core(self) -> bool:
"""Return whether the required problem definition is present."""
return (
self.horizon is not None
and self.dynamics is not None
and self.stage_cost is not None
and self.terminal_cost is not None
)
def _require_complete(self) -> None:
"""Raise a helpful error when a builder-style problem is incomplete."""
missing: list[str] = []
if self.horizon is None:
missing.append("horizon")
if self.dynamics is None:
missing.append("dynamics")
if self.stage_cost is None:
missing.append("stage_cost")
if self.terminal_cost is None:
missing.append("terminal_cost")
if missing:
joined = ", ".join(missing)
raise ValueError(
"SingleShootingProblem is incomplete; configure "
f"{joined} before expanding or generating code"
)
def _dynamics(self) -> Function:
"""Return configured dynamics after checking completeness."""
self._require_complete()
assert self.dynamics is not None
return self.dynamics
def _stage_cost(self) -> Function:
"""Return configured stage cost after checking completeness."""
self._require_complete()
assert self.stage_cost is not None
return self.stage_cost
def _terminal_cost(self) -> Function:
"""Return configured terminal cost after checking completeness."""
self._require_complete()
assert self.terminal_cost is not None
return self.terminal_cost
def _horizon(self) -> int:
"""Return configured horizon after checking completeness."""
self._require_complete()
assert self.horizon is not None
return self.horizon
[docs]
def primal(
self,
*,
include_states: bool = False,
name: str | None = None,
) -> SingleShootingProblem | SingleShootingPrimalFunction:
"""Return a staged primal kernel source."""
self._require_complete()
if not include_states and name is None:
return self
return SingleShootingPrimalFunction(
problem=self,
name=name
or _single_shooting_primal_name(self.name, include_states),
include_states=include_states,
simplification=self.simplification,
)
[docs]
def gradient(
self,
*,
include_states: bool = False,
name: str | None = None,
) -> SingleShootingGradientFunction:
"""Return a staged gradient kernel source."""
self._require_complete()
return SingleShootingGradientFunction(
problem=self,
name=name or f"{self.name}_gradient_{self.control_sequence_name}",
include_states=include_states,
simplification=self.simplification,
)
[docs]
def hvp(
self,
*,
include_states: bool = False,
name: str | None = None,
) -> SingleShootingHvpFunction:
"""Return a staged Hessian-vector-product kernel source."""
self._require_complete()
return SingleShootingHvpFunction(
problem=self,
name=name or f"{self.name}_hvp_{self.control_sequence_name}",
include_states=include_states,
simplification=self.simplification,
)
[docs]
def joint(
self,
bundle: SingleShootingBundle,
*,
name: str | None = None,
) -> SingleShootingJointFunction:
"""Return a staged joint kernel source."""
self._require_complete()
_validate_single_shooting_bundle(bundle)
return SingleShootingJointFunction(
problem=self,
bundle=bundle,
name=name
or _single_shooting_joint_name(
self.name, bundle, self.control_sequence_name
),
simplification=self.simplification,
)
[docs]
def to_function(
self,
*,
include_states: bool = False,
name: str | None = None,
) -> Function:
"""Expand the total-cost kernel into a symbolic ``Function``."""
self._require_complete()
dynamics = self._dynamics()
horizon = self._horizon()
compiled_inputs = self._compiled_inputs()
x0, U, p = compiled_inputs[:3]
penalty_weight = (
compiled_inputs[3]
if self.has_runtime_penalty_weight
else None
)
current_state: FunctionArg = x0
rollout_states: list[FunctionArg] = [current_state]
total_cost = SX.const(0.0)
for stage_index in range(horizon):
u_t = _slice_packed_sequence(
U, stage_index, self.control_size, dynamics.inputs[1]
)
total_cost = total_cost + self._stage_total_cost(
current_state, u_t, p, penalty_weight
)
current_state = _extract_single_output(
dynamics(current_state, u_t, p)
)
rollout_states.append(current_state)
total_cost = total_cost + self._terminal_total_cost(
current_state, p, penalty_weight
)
outputs: list[FunctionArg] = [total_cost]
output_names = ["cost"]
if include_states:
outputs.append(_flatten_rollout_states(rollout_states))
output_names.append("x_traj")
function = Function(
name or self.name,
compiled_inputs,
outputs,
input_names=self.input_names,
output_names=tuple(output_names),
)
return _simplify_function(function, self.simplification)
def _expanded_gradient_function(
self,
*,
include_states: bool,
name: str,
) -> Function:
"""Expand the staged gradient kernel into a symbolic ``Function``."""
cost_function = self.to_function(
include_states=False, name=f"{name}_cost"
)
gradient_function = cost_function.gradient(1, name=f"{name}_grad")
outputs: list[FunctionArg] = [gradient_function.outputs[0]]
output_names = [f"gradient_{self.control_sequence_name}"]
if include_states:
outputs.append(
self.to_function(
include_states=True, name=f"{name}_states"
).outputs[1]
)
output_names.append("x_traj")
return Function(
name,
gradient_function.inputs,
outputs,
input_names=gradient_function.input_names,
output_names=tuple(output_names),
)
def _expanded_hvp_function(
self,
*,
include_states: bool,
name: str,
) -> Function:
"""Expand the staged HVP kernel into a symbolic ``Function``."""
cost_function = self.to_function(
include_states=False, name=f"{name}_cost"
)
hvp_function = cost_function.hvp(1, name=f"{name}_hvp")
outputs: list[FunctionArg] = [hvp_function.outputs[0]]
output_names = [f"hvp_{self.control_sequence_name}"]
if include_states:
outputs.append(
self.to_function(
include_states=True, name=f"{name}_states"
).outputs[1]
)
output_names.append("x_traj")
return Function(
name,
hvp_function.inputs,
outputs,
input_names=hvp_function.input_names,
output_names=tuple(output_names),
)
def _expanded_joint_function(
self,
bundle: SingleShootingBundle,
*,
name: str,
) -> Function:
"""Expand a staged joint kernel into a symbolic ``Function``."""
_validate_single_shooting_bundle(bundle)
cost_function = self.to_function(
include_states=False, name=f"{name}_cost"
)
gradient_function = cost_function.gradient(1, name=f"{name}_grad")
hvp_function = (
cost_function.hvp(1, name=f"{name}_hvp")
if bundle.include_hvp
else None
)
outputs: list[FunctionArg] = []
output_names: list[str] = []
if bundle.include_cost:
outputs.append(cost_function.outputs[0])
output_names.append("cost")
if bundle.include_gradient:
outputs.append(gradient_function.outputs[0])
output_names.append(f"gradient_{self.control_sequence_name}")
if bundle.include_hvp:
assert hvp_function is not None
outputs.append(hvp_function.outputs[0])
output_names.append(f"hvp_{self.control_sequence_name}")
if bundle.include_states:
outputs.append(
self.to_function(
include_states=True, name=f"{name}_states"
).outputs[1]
)
output_names.append("x_traj")
inputs = (
hvp_function.inputs
if hvp_function is not None
else cost_function.inputs
)
input_names = (
hvp_function.input_names
if hvp_function is not None
else cost_function.input_names
)
return Function(
name,
inputs,
outputs,
input_names=input_names,
output_names=tuple(output_names),
)
def _compiled_inputs(self) -> tuple[FunctionArg, ...]:
"""Return symbolic runtime inputs."""
dynamics = self._dynamics()
horizon = self._horizon()
x0 = _make_symbolic_like(
dynamics.inputs[0], self.initial_state_name
)
p = _make_symbolic_like(dynamics.inputs[2], self.parameter_name)
U = SXVector.sym(
self.control_sequence_name, horizon * self.control_size
)
if self.has_runtime_penalty_weight:
return x0, U, p, SX.sym(self.penalty_weight_name or "c")
return x0, U, p
[docs]
def stage_total_cost_function(self) -> Function:
"""Return the scalar stage cost including residual penalties.
The returned function has the same ``(x, u, p)`` signature as
:attr:`stage_cost` when ``penalty_weight`` is numeric, and
``(x, u, p, c)`` when ``penalty_weight`` is symbolic. When
:attr:`stage_penalty` is present its output is squared, summed,
multiplied by ``penalty_weight / 2``, and added to the base stage
cost.
Returns:
A scalar :class:`~gradgen.function.Function` for the effective
stage contribution used by primal, gradient, HVP, and Rust
code-generation paths.
"""
stage_cost = self._stage_cost()
x, u, p = stage_cost.inputs
penalty_weight = self._helper_penalty_weight_symbol()
inputs = (x, u, p)
input_names = stage_cost.input_names
if penalty_weight is not None:
inputs = (*inputs, penalty_weight)
input_names = (*input_names, self.penalty_weight_name or "c")
return Function(
f"{stage_cost.name}_with_penalty",
inputs,
(self._stage_total_cost(x, u, p, penalty_weight),),
input_names=input_names,
output_names=stage_cost.output_names,
)
[docs]
def terminal_total_cost_function(self) -> Function:
"""Return the scalar terminal cost including residual penalties.
The returned function has the same ``(x, p)`` signature as
:attr:`terminal_cost` when ``penalty_weight`` is numeric, and
``(x, p, c)`` when ``penalty_weight`` is symbolic. When
:attr:`terminal_penalty` is present its output is squared, summed,
multiplied by ``penalty_weight / 2``, and added to the base terminal
cost.
Returns:
A scalar :class:`~gradgen.function.Function` for the effective
terminal contribution used by primal, gradient, HVP, and Rust
code-generation paths.
"""
terminal_cost = self._terminal_cost()
x, p = terminal_cost.inputs
penalty_weight = self._helper_penalty_weight_symbol()
inputs = (x, p)
input_names = terminal_cost.input_names
if penalty_weight is not None:
inputs = (*inputs, penalty_weight)
input_names = (*input_names, self.penalty_weight_name or "c")
return Function(
f"{terminal_cost.name}_with_penalty",
inputs,
(self._terminal_total_cost(x, p, penalty_weight),),
input_names=input_names,
output_names=terminal_cost.output_names,
)
def _stage_total_cost(
self,
x: FunctionArg,
u: FunctionArg,
p: FunctionArg,
penalty_weight: SX | None = None,
) -> SX:
"""Return the symbolic stage cost including residual penalties."""
cost = _extract_scalar_output(self._stage_cost()(x, u, p))
if self.stage_penalty is None:
return cost
return cost + self._weighted_squared_norm(
_extract_single_output(self.stage_penalty(x, u, p)),
penalty_weight,
)
def _terminal_total_cost(
self,
x: FunctionArg,
p: FunctionArg,
penalty_weight: SX | None = None,
) -> SX:
"""Return the symbolic terminal cost including residual penalties."""
cost = _extract_scalar_output(self._terminal_cost()(x, p))
if self.terminal_penalty is None:
return cost
return cost + self._weighted_squared_norm(
_extract_single_output(self.terminal_penalty(x, p)),
penalty_weight,
)
def _weighted_squared_norm(
self,
residual: FunctionArg,
penalty_weight: SX | None = None,
) -> SX:
"""Return ``penalty_weight / 2 * ||residual||_2^2``."""
if self.penalty_weight is None:
raise ValueError("penalty_weight must be provided")
if penalty_weight is not None:
weight = penalty_weight
elif isinstance(self.penalty_weight, SX):
weight = self.penalty_weight
else:
weight = SX.const(float(self.penalty_weight))
return (SX.const(0.5) * weight) * _squared_norm(residual)
def _helper_penalty_weight_symbol(self) -> SX | None:
"""Return the helper input symbol for runtime penalty weights."""
if not self.has_runtime_penalty_weight:
return None
return SX.sym(self.penalty_weight_name or "c")
[docs]
def generate_rust(
self,
*,
config=None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Generate compact Rust for the primal total-cost kernel."""
self._require_complete()
return _generate_rust(
self,
config=config,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
[docs]
def create_rust_project(
self,
path: str,
*,
config=None,
crate_name: str | None = None,
function_name: str | None = None,
backend_mode: str = "std",
scalar_type: str = "f64",
):
"""Create a Rust crate containing the total-cost kernel."""
self._require_complete()
return _create_rust_project(
self,
path,
config=config,
crate_name=crate_name,
function_name=function_name,
backend_mode=backend_mode,
scalar_type=scalar_type,
)
def _single_shooting_arg_size(value: FunctionArg) -> int:
    """Return how many scalars ``value`` flattens to.

    A scalar ``SX`` counts as one entry; any other value is measured with
    ``len``.
    """
    return 1 if isinstance(value, SX) else len(value)
def _same_single_shooting_shape(left: FunctionArg, right: FunctionArg) -> bool:
    """Return ``True`` when ``left`` and ``right`` have matching shapes.

    Two scalars always match. Two ``SXVector`` values match when their
    lengths agree. Every other pairing (including scalar vs. vector) is a
    mismatch.
    """
    left_scalar = isinstance(left, SX)
    right_scalar = isinstance(right, SX)
    if left_scalar or right_scalar:
        return left_scalar and right_scalar
    both_vectors = isinstance(left, SXVector) and isinstance(right, SXVector)
    return both_vectors and len(left) == len(right)
def _make_symbolic_like(value: FunctionArg, base_name: str) -> FunctionArg:
    """Return a new symbol named ``base_name`` shaped like ``value``.

    A scalar ``value`` yields ``SX.sym(base_name)``. Otherwise an
    ``SXVector.sym`` of the same length is created.
    """
    if not isinstance(value, SX):
        return SXVector.sym(base_name, len(value))
    return SX.sym(base_name)
def _slice_packed_sequence(
    sequence: SXVector,
    stage_index: int,
    block_size: int,
    formal: FunctionArg,
) -> FunctionArg:
    """Return the ``stage_index``-th block of a packed control sequence.

    Blocks are contiguous: block ``k`` starts at ``k * block_size``. A
    scalar ``formal`` selects the single element at the block start.
    Otherwise an ``SXVector`` over ``block_size`` elements is returned.
    """
    offset = block_size * stage_index
    if not isinstance(formal, SX):
        block = sequence.elements[offset : offset + block_size]
        return SXVector(block)
    return sequence[offset]
def _extract_single_output(value: object) -> FunctionArg:
"""Normalize a single-output function call result."""
if isinstance(value, tuple):
if len(value) != 1:
raise ValueError("expected a single function output")
return value[0]
if isinstance(value, (SX, SXVector)):
return value
raise TypeError("single-shooting stages must return symbolic outputs")
def _extract_scalar_output(value: object) -> SX:
    """Normalize a call result and require a scalar ``SX``.

    Raises:
        ValueError: If the single output is not a scalar ``SX``.
    """
    scalar = _extract_single_output(value)
    if isinstance(scalar, SX):
        return scalar
    raise ValueError(
        "single-shooting cost functions must return scalar outputs"
    )
def _squared_norm(value: FunctionArg) -> SX:
    """Return ``sum_i value_i ** 2`` for a scalar or vector symbolic value.

    A scalar is simply multiplied by itself. A vector's squares are summed
    left to right, starting from ``SX.const(0.0)``.
    """
    if isinstance(value, SX):
        return value * value
    return sum(
        (component * component for component in value.elements),
        SX.const(0.0),
    )
def _flatten_rollout_states(states: list[FunctionArg]) -> SXVector:
    """Concatenate a rollout's states into a single packed ``SXVector``.

    Scalar states contribute one element each. Vector states contribute
    their ``elements`` in order.
    """
    packed: list[SX] = []
    for state in states:
        packed.extend((state,) if isinstance(state, SX) else state.elements)
    return SXVector(tuple(packed))
def _single_shooting_primal_name(base_name: str, include_states: bool) -> str:
    """Return the default name of a primal wrapper kernel.

    The name is ``"<base_name>_with_states"`` when states are exposed,
    otherwise ``base_name`` itself.
    """
    return f"{base_name}_with_states" if include_states else base_name
def _single_shooting_joint_name(
    base_name: str,
    bundle: SingleShootingBundle,
    control_sequence_name: str,
) -> str:
    """Return the default joint-kernel name.

    The name is ``base_name`` followed by ``_`` and the ``_``-joined labels
    of the requested outputs, in the order cost, gradient, HVP, states.
    """
    requested = (
        (bundle.include_cost, "cost"),
        (bundle.include_gradient, f"gradient_{control_sequence_name}"),
        (bundle.include_hvp, f"hvp_{control_sequence_name}"),
        (bundle.include_states, "states"),
    )
    suffix = "_".join(label for wanted, label in requested if wanted)
    return f"{base_name}_{suffix}"
def _single_shooting_bundle_output_names(
    problem: SingleShootingProblem,
    bundle: SingleShootingBundle,
) -> tuple[str, ...]:
    """Return the output names of a joint kernel for ``bundle``.

    Names follow the fixed order ``cost``, ``gradient_<U>``, ``hvp_<U>``,
    ``x_traj``. Only the requested entries are kept.
    """
    control = problem.control_sequence_name
    requested = (
        (bundle.include_cost, "cost"),
        (bundle.include_gradient, f"gradient_{control}"),
        (bundle.include_hvp, f"hvp_{control}"),
        (bundle.include_states, "x_traj"),
    )
    return tuple(label for wanted, label in requested if wanted)
def _validate_single_shooting_bundle(bundle: SingleShootingBundle) -> None:
    """Reject bundles that cannot form a joint kernel.

    Raises:
        ValueError: If none of cost, gradient, or HVP is requested, or if
            fewer than two outputs (counting states) are requested.
    """
    primary_flags = (
        bundle.include_cost,
        bundle.include_gradient,
        bundle.include_hvp,
    )
    if not any(primary_flags):
        raise ValueError(
            "SingleShootingBundle must request at least cost, gradient, or HVP"
        )
    requested_count = sum(primary_flags) + bundle.include_states
    if requested_count < 2:
        raise ValueError(
            "joint single-shooting kernels require at least two "
            "requested outputs"
        )
def _validate_single_shooting_dynamics(function: Function) -> None:
"""Validate the dynamics function signature."""
if len(function.inputs) != 3:
raise ValueError(
"SingleShootingProblem dynamics must accept (x, u, p)"
)
if len(function.outputs) != 1:
raise ValueError(
"SingleShootingProblem dynamics must return exactly one output"
)
if not _same_single_shooting_shape(
function.inputs[0], function.outputs[0]
):
raise ValueError(
"SingleShootingProblem dynamics must return the next state "
"with the same shape as x"
)
def _validate_single_shooting_stage_cost(function: Function) -> None:
"""Validate the stage-cost function signature."""
if len(function.inputs) != 3:
raise ValueError(
"SingleShootingProblem stage_cost must accept (x, u, p)"
)
if len(function.outputs) != 1 or not isinstance(function.outputs[0], SX):
raise ValueError(
"SingleShootingProblem stage_cost must return exactly one "
"scalar output"
)
def _validate_single_shooting_terminal_cost(function: Function) -> None:
"""Validate the terminal-cost function signature."""
if len(function.inputs) != 2:
raise ValueError(
"SingleShootingProblem terminal_cost must accept (x, p)"
)
if len(function.outputs) != 1 or not isinstance(function.outputs[0], SX):
raise ValueError(
"SingleShootingProblem terminal_cost must return exactly one "
"scalar output"
)
def _validate_single_shooting_penalty_configuration(
stage_penalty: Function | None,
terminal_penalty: Function | None,
penalty_weight: float | SX | None,
) -> None:
"""Validate optional residual penalty fields are supplied together."""
has_any_penalty = stage_penalty is not None or terminal_penalty is not None
if has_any_penalty and (
stage_penalty is None
or terminal_penalty is None
or penalty_weight is None
):
raise ValueError(
"SingleShootingProblem penalties require stage_penalty, "
"terminal_penalty, and penalty_weight"
)
if isinstance(penalty_weight, SXVector):
raise TypeError("penalty_weight must be a scalar")
def _validate_single_shooting_stage_penalty(function: Function) -> None:
"""Validate the stage residual signature."""
if len(function.inputs) != 3:
raise ValueError(
"SingleShootingProblem stage_penalty must accept (x, u, p)"
)
if len(function.outputs) != 1:
raise ValueError(
"SingleShootingProblem stage_penalty must return exactly one "
"residual output"
)
if not isinstance(function.outputs[0], (SX, SXVector)):
raise ValueError(
"SingleShootingProblem stage_penalty must return a scalar or "
"vector symbolic residual"
)
def _validate_single_shooting_terminal_penalty(function: Function) -> None:
"""Validate the terminal residual signature."""
if len(function.inputs) != 2:
raise ValueError(
"SingleShootingProblem terminal_penalty must accept (x, p)"
)
if len(function.outputs) != 1:
raise ValueError(
"SingleShootingProblem terminal_penalty must return exactly one "
"residual output"
)
if not isinstance(function.outputs[0], (SX, SXVector)):
raise ValueError(
"SingleShootingProblem terminal_penalty must return a scalar or "
"vector symbolic residual"
)
def _validate_single_shooting_shapes(
    dynamics: Function,
    stage_cost: Function,
    terminal_cost: Function,
    stage_penalty: Function | None = None,
    terminal_penalty: Function | None = None,
) -> None:
    """Check that every stage function agrees with the dynamics signature.

    ``stage_cost`` (and ``stage_penalty`` when given) must take ``x``, ``u``
    and ``p`` shaped like the dynamics inputs. ``terminal_cost`` (and
    ``terminal_penalty`` when given) must take ``x`` shaped like the dynamics
    output and ``p`` shaped like the dynamics ``p`` input. Checks run in a
    fixed order, and the first mismatch is reported.

    Raises:
        ValueError: On the first input whose shape disagrees with dynamics.
    """
    dyn_inputs = dynamics.inputs
    dyn_outputs = dynamics.outputs
    # Each entry: (reference sequence, reference index, function to check,
    # its input index, error message). Indexing is deferred to the loop so
    # nothing is read past the first failing check.
    checks: list[tuple[object, int, Function, int, str]] = [
        (
            dyn_inputs, 0, stage_cost, 0,
            "stage_cost x input must have the same shape as dynamics x",
        ),
        (
            dyn_inputs, 1, stage_cost, 1,
            "stage_cost u input must have the same shape as dynamics u",
        ),
        (
            dyn_inputs, 2, stage_cost, 2,
            "stage_cost p input must have the same shape as dynamics p",
        ),
        (
            dyn_outputs, 0, terminal_cost, 0,
            "terminal_cost x input must have the same shape as the "
            "dynamics state",
        ),
        (
            dyn_inputs, 2, terminal_cost, 1,
            "terminal_cost p input must have the same shape as dynamics p",
        ),
    ]
    if stage_penalty is not None:
        checks.extend(
            [
                (
                    dyn_inputs, 0, stage_penalty, 0,
                    "stage_penalty x input must have the same shape as "
                    "dynamics x",
                ),
                (
                    dyn_inputs, 1, stage_penalty, 1,
                    "stage_penalty u input must have the same shape as "
                    "dynamics u",
                ),
                (
                    dyn_inputs, 2, stage_penalty, 2,
                    "stage_penalty p input must have the same shape as "
                    "dynamics p",
                ),
            ]
        )
    if terminal_penalty is not None:
        checks.extend(
            [
                (
                    dyn_outputs, 0, terminal_penalty, 0,
                    "terminal_penalty x input must have the same shape as "
                    "the dynamics state",
                ),
                (
                    dyn_inputs, 2, terminal_penalty, 1,
                    "terminal_penalty p input must have the same shape as "
                    "dynamics p",
                ),
            ]
        )
    for reference, ref_index, candidate, cand_index, message in checks:
        if not _same_single_shooting_shape(
            reference[ref_index], candidate.inputs[cand_index]
        ):
            raise ValueError(message)