import pathlib as pl
from bagel.callbacks import Callback, CallbackContext
class CheckpointCallback(Callback):
    """Save the best system's sequences to a FASTA file every N steps.

    On every step whose number is a multiple of ``checkpoint_interval`` a
    file ``checkpoint_step_<step>.fasta`` is written into ``output_dir`` and
    the current best energy is printed. When the optimization ends, a
    ``final_summary.txt`` file is written to the same directory.
    """

    def __init__(self, checkpoint_interval: int = 100, output_dir: str = "checkpoints"):
        """Configure the checkpoint cadence and destination directory.

        Args:
            checkpoint_interval: Number of steps between checkpoints. Must be
                a positive integer.
            output_dir: Directory that receives the checkpoint and summary
                files. It is created on demand.

        Raises:
            ValueError: If ``checkpoint_interval`` is not positive.
        """
        # Fail fast: a zero interval would otherwise only surface as a
        # ZeroDivisionError on the first ``on_step_end`` call.
        if checkpoint_interval < 1:
            raise ValueError(
                f"checkpoint_interval must be a positive integer, got {checkpoint_interval!r}"
            )
        self.checkpoint_interval = checkpoint_interval
        self.output_dir = pl.Path(output_dir)

    def _ensure_output_dir(self) -> None:
        """Create the output directory (and any missing parents) if needed."""
        self.output_dir.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _format_sequence(state) -> str:
        """Join the entries of ``state.total_sequence`` with ``:``."""
        return ":".join(state.total_sequence)

    def on_optimization_start(self, context: CallbackContext) -> None:
        """Make sure the output directory exists before any step runs."""
        self._ensure_output_dir()

    def on_step_end(self, context: CallbackContext) -> None:
        """Write a FASTA checkpoint and print the best energy every N steps."""
        if context.step % self.checkpoint_interval != 0:
            return

        # Do not rely on ``on_optimization_start`` having been called (or on
        # the directory still being there) when it is time to write.
        self._ensure_output_dir()

        # Save best system sequences
        filepath = self.output_dir / f"checkpoint_step_{context.step}.fasta"
        with open(filepath, "w", encoding="utf-8") as f:
            for state in context.best_system.states:
                f.write(f">{state.name}_step{context.step}\n")
                f.write(f"{self._format_sequence(state)}\n")

        # Log energy progress; fall back to +inf when the metric is missing.
        energy = context.metrics.get("best_system_energy", float("inf"))
        print(f"Step {context.step}: best energy = {energy:.4f}")

    def on_optimization_end(self, context: CallbackContext) -> None:
        """Write ``final_summary.txt`` with the final energy, step and sequences."""
        self._ensure_output_dir()

        # Same fallback as ``on_step_end`` so a missing metric does not abort
        # the summary with a KeyError.
        energy = context.metrics.get("best_system_energy", float("inf"))

        # Save final summary
        filepath = self.output_dir / "final_summary.txt"
        with open(filepath, "w", encoding="utf-8") as f:
            f.write(f"Final best energy: {energy:.4f}\n")
            f.write(f"Total steps: {context.step}\n")
            for state in context.best_system.states:
                f.write(f"\n{state.name}:\n")
                f.write(f" Sequence: {self._format_sequence(state)}\n")