# boileroom/models/{name}/core.py
import logging
from typing import Optional, Union, Sequence
import torch
from ...base import FoldingAlgorithm
from ...utils import Timer, MODAL_MODEL_DIR
from .types import MyModelOutput
logger = logging.getLogger(__name__)
class MyModelCore(FoldingAlgorithm):
    """Core algorithm for MyModel structure prediction.

    This is a skeleton: the model loading, tokenization and inference steps
    are placeholders to be filled in for a concrete model. Helper methods
    used here but not defined in this class (``_initialize_metadata``,
    ``_resolve_device``, ``_validate_sequences``, ``_merge_options``,
    ``_compute_sequence_lengths``, ``_filter_include_fields``) are expected
    to come from ``FoldingAlgorithm``.
    """

    DEFAULT_CONFIG = {
        "device": "cuda:0",
        "include_fields": None,  # Optional[List[str]]
        # Add model-specific defaults here
    }
    # Keys that cannot be overridden per-call via options
    STATIC_CONFIG_KEYS = {"device"}

    def __init__(self, config: Optional[dict] = None) -> None:
        """Set up identity, metadata and empty model/tokenizer slots.

        Args:
            config: Configuration mapping forwarded to the base class; an
                empty dict is used when ``None`` (or any falsy value) is
                given.
        """
        super().__init__(config or {})
        self.name = "MyModel"
        self.version = "1.0.0"
        self.metadata = self._initialize_metadata(
            model_name=self.name,
            model_version=self.version,
        )
        # Populated by _load(); kept as None until the weights are loaded.
        self.model = None
        self.tokenizer = None

    def _initialize(self) -> None:
        """Entry point called by the Modal wrapper after construction."""
        self._load()

    def _load(self) -> None:
        """Load model weights and move to the resolved device.

        Marks the instance as ready by setting ``self.ready = True``.
        """
        # Only consumed by the commented-out template lines below until a
        # real model is plugged in.
        device = self._resolve_device()
        # Load your model here
        # self.model = MyModelLib.from_pretrained(...)
        # self.model = self.model.to(device)
        # self.model.eval()
        self.ready = True

    def fold(
        self,
        sequences: Union[str, Sequence[str]],
        options: Optional[dict] = None,
    ) -> MyModelOutput:
        """Run structure prediction.

        Args:
            sequences: A single sequence string or a sequence of strings;
                passed through ``self._validate_sequences`` before use.
            options: Optional per-call settings, combined with the instance
                configuration by ``self._merge_options``.

        Returns:
            The converted model output with sequence lengths and per-stage
            timings recorded on its metadata, passed through
            ``self._filter_include_fields`` using the merged config's
            ``"include_fields"`` entry.
        """
        sequences = self._validate_sequences(sequences)
        config = self._merge_options(options)

        with Timer("Preprocessing") as preprocess_timer:
            # Tokenize, prepare inputs
            pass

        with Timer("Inference") as inference_timer:
            # Run model inference
            # Placeholder: bind the real inference result here.
            raw_output = ...

        with Timer("Postprocessing") as postprocess_timer:
            output = self._convert_outputs(
                raw_output=raw_output,
                sequences=sequences,
                config=config,
            )

        # Timings are recorded after every timer block has exited.
        # NOTE(review): presumably Timer finalizes `duration` on exit - confirm.
        output.metadata.sequence_lengths = self._compute_sequence_lengths(sequences)
        output.metadata.preprocessing_time = preprocess_timer.duration
        output.metadata.inference_time = inference_timer.duration
        output.metadata.postprocessing_time = postprocess_timer.duration

        return self._filter_include_fields(output, config.get("include_fields"))

    def _convert_outputs(
        self,
        raw_output: object,
        sequences: Sequence[str],
        config: dict,
    ) -> MyModelOutput:
        """Convert raw model output into the typed dataclass.

        Args:
            raw_output: Raw inference result (a placeholder in this
                template; unused until conversion is implemented).
            sequences: The validated input sequences (currently unused).
            config: The merged per-call configuration (currently unused).

        Returns:
            A ``MyModelOutput`` carrying a metadata object created for this
            call and a placeholder ``atom_array``.
        """
        metadata = self._initialize_metadata(
            model_name=self.name,
            model_version=self.version,
        )
        return MyModelOutput(
            metadata=metadata,
            atom_array=...,  # Convert to list of Biotite AtomArray
        )