Adding a custom oracle
Oracles provide predictions (structure, embeddings, etc.) that energy terms use to compute scores. This guide shows you how to add your own.

The Oracle base class

All oracles inherit from Oracle and must implement a predict() method that takes a list of chains and returns an OracleResult:
from bagel.oracles.base import Oracle, OracleResult
from bagel.chain import Chain


class MyOracle(Oracle):
    """Skeleton of a custom oracle.

    ``result_class`` names the OracleResult subclass this oracle produces;
    ``predict`` runs the model on the given chains and wraps its output.
    """

    result_class = MyResult  # must be set to your result class

    def predict(self, chains: list[Chain]) -> OracleResult:
        """Run your model on ``chains`` and return its results."""
        ...
Key points:
  • result_class must be set to your custom result class (a subclass of OracleResult). This enables type checking when energy terms access results.
  • Oracles are not copied during system copies — they are shared references. This avoids duplicating heavy model weights in memory.

Defining a result class

Result classes are Pydantic models (OracleResult builds on Pydantic's BaseModel) and must implement save_attributes():
import pathlib as pl
from bagel.oracles.base import OracleResult
from bagel.chain import Chain


class MyResult(OracleResult):
    """Example result holding a scalar score and a list of floats."""

    input_chains: list[Chain]  # chains that were passed to the oracle
    my_score: float
    my_array: list[float]

    class Config:
        # Lets fields hold non-Pydantic types such as numpy arrays or biotite objects.
        arbitrary_types_allowed = True

    def save_attributes(self, filepath: pl.Path) -> None:
        """Write ``my_score`` as text to ``<filepath>.score.txt`` (called by loggers)."""
        score_file = pl.Path(f"{filepath}.score.txt")
        score_file.write_text(str(self.my_score))
The input_chains field is inherited from OracleResult and is always required — it records which chains were passed to the oracle. Redeclaring it in your subclass, as in the example above, is optional.

Implementing predict()

The predict() method receives a list of Chain objects and must return an instance of your result class:
class MyOracle(Oracle):
    """Example oracle that scores the sequences of a batch of chains with a loaded model."""

    result_class = MyResult

    def __init__(self, model_path: str):
        # NOTE(review): super().__init__() is not called here; confirm that
        # Oracle's constructor needs no setup before relying on this.
        self.model = load_model(model_path)

    def predict(self, chains: list[Chain]) -> MyResult:
        """Score the sequences of ``chains`` and wrap the result in a MyResult.

        ``my_array`` is a placeholder holding one 0.0 per input chain.
        """
        seqs = [c.sequence for c in chains]
        model_score = self.model.score(seqs)
        placeholder = [0.0 for _ in seqs]
        return MyResult(input_chains=chains, my_score=model_score, my_array=placeholder)

Folding vs embedding oracles

For common prediction types, BAGEL provides specialized base classes with more structure:

FoldingOracle

Use this for models that predict 3D structures. Implement fold() instead of predict():
from bagel.oracles.folding import FoldingOracle, FoldingResult
from bagel.chain import Chain
from biotite.structure import AtomArray


class MyFoldingOracle(FoldingOracle):
    """Example folding oracle; FoldingOracle routes predict() to fold()."""

    def fold(self, chains: list[Chain]) -> FoldingResult:
        """Predict a 3D structure for ``chains`` and wrap it in a FoldingResult."""
        # Placeholder for your structure prediction model.
        predicted: AtomArray = run_prediction(chains)
        # ``structure`` is a biotite AtomArray with 3D coordinates.
        # Add any additional metrics your model provides here.
        return FoldingResult(input_chains=chains, structure=predicted)
FoldingOracle automatically routes predict() to fold(). Energy terms access the structure via oracles_result.get_structure(oracle).

EmbeddingOracle

Use this for language models that produce per-residue embeddings. Implement embed(), _pre_process(), and _post_process():
import numpy as np
from bagel.oracles.embedding import EmbeddingOracle, EmbeddingResult
from bagel.chain import Chain


class MyEmbeddingOracle(EmbeddingOracle):
    """Example embedding oracle built from embed / _pre_process / _post_process.

    NOTE(review): ``self.model`` is not set in this class; it is assumed to be
    provided by a constructor not shown here.
    """

    def embed(self, chains: list[Chain]) -> EmbeddingResult:
        """Embed ``chains`` and return the resulting EmbeddingResult."""
        # _post_process() receives only the model output, so stash the input
        # chains where it can read them. Without this assignment,
        # _post_process() would fail with AttributeError.
        # Note this is per-call state on an oracle instance that may be shared.
        self._current_chains = chains
        processed = self._pre_process(chains)
        output = self.model(processed)
        return self._post_process(output)

    def _pre_process(self, chains: list[Chain]):
        """Convert chains to model input format (one sequence per chain)."""
        return [chain.sequence for chain in chains]

    def _post_process(self, output) -> EmbeddingResult:
        """Convert model output to an EmbeddingResult for the chains passed to embed()."""
        return EmbeddingResult(
            input_chains=self._current_chains,
            embeddings=np.array(output),  # shape: (total_residues, embedding_dim)
        )
Energy terms access embeddings via oracles_result.get_embeddings(oracle).

Example: a custom oracle

Here is a simple property oracle that computes a solubility score based on sequence composition:
import pathlib as pl
import numpy as np
from bagel.oracles.base import Oracle, OracleResult
from bagel.chain import Chain


# Per-residue solubility contributions (simplified, illustrative values),
# keyed by one-letter amino acid code. D/E score highest, I/L lowest.
SOLUBILITY_SCORES = dict(
    D=1.0, E=1.0, K=0.9, R=0.8, N=0.7, Q=0.7,
    S=0.5, T=0.5, H=0.4, G=0.3, A=0.2, P=0.2,
    Y=-0.1, W=-0.3, F=-0.5, M=-0.4,
    V=-0.6, I=-0.7, L=-0.7, C=-0.2,
)


class SolubilityResult(OracleResult):
    """Result of SolubilityOracle: per-residue scores plus one overall score."""

    input_chains: list[Chain]  # chains that were passed to the oracle
    per_residue_scores: list[float]  # one score per residue, chains in input order
    overall_score: float  # aggregate of per_residue_scores computed by the oracle

    def save_attributes(self, filepath: pl.Path) -> None:
        """Save the per-residue scores to ``<filepath>.solubility.txt``."""
        out_path = f"{filepath}.solubility.txt"
        np.savetxt(out_path, self.per_residue_scores)


class SolubilityOracle(Oracle):
    """Predicts sequence solubility from amino acid composition.

    Each residue is scored via ``SOLUBILITY_SCORES`` (residue names missing
    from the table score 0.0). The overall score is the mean over all residues
    of all input chains.
    """

    result_class = SolubilityResult

    def predict(self, chains: list[Chain]) -> SolubilityResult:
        """Score every residue of ``chains`` and return a SolubilityResult.

        If the chains contain no residues at all, ``overall_score`` is 0.0.
        Otherwise ``np.mean`` of an empty list would yield NaN (with a
        RuntimeWarning), and that NaN would propagate into any energy term
        using this oracle.
        """
        per_residue = [
            SOLUBILITY_SCORES.get(residue.name, 0.0)
            for chain in chains
            for residue in chain.residues
        ]
        overall = float(np.mean(per_residue)) if per_residue else 0.0

        return SolubilityResult(
            input_chains=chains,
            per_residue_scores=per_residue,
            overall_score=overall,
        )
You can then write a custom energy term that uses this oracle:
from bagel.energies import EnergyTerm
from bagel.oracles import OraclesResultDict


class SolubilityEnergy(EnergyTerm):
    """Energy term based on SolubilityOracle: energy is the negated overall solubility."""

    def __init__(self, oracle: SolubilityOracle, weight: float = 1.0):
        super().__init__(name="SolubilityEnergy", oracle=oracle, inheritable=True, weight=weight)

    def compute(self, oracles_result: OraclesResultDict) -> tuple[float, float]:
        """Return ``(unweighted, weighted)`` energy from this term's oracle result."""
        solubility = oracles_result[self.oracle].overall_score
        # Negate so that higher solubility gives lower energy.
        energy = -solubility
        return energy, energy * self.weight