Energy terms define what BAGEL optimizes for. This guide shows you how to create your own.

The EnergyTerm base class

All energy terms inherit from EnergyTerm (defined in bagel/energies.py). The base class handles weight management, oracle association, residue group tracking, and integration with the grand-canonical ensemble. Your __init__ must call the parent constructor with these arguments:
from bagel.energies import EnergyTerm, residue_list_to_group

class MyEnergy(EnergyTerm):
    def __init__(self, oracle, weight=1.0, residues=None):
        super().__init__(
            name="MyEnergy",          # unique identifier
            oracle=oracle,            # which oracle provides data
            inheritable=True,         # whether new residues inherit this term (GrandCanonical)
            weight=weight,
        )
        # Convert each list of residues into a ResidueGroup
        if residues is not None:
            self.residue_groups = [residue_list_to_group(group) for group in residues]
Key attributes:
  • name — a string identifier, used in logging and output files
  • oracle — the oracle instance that provides predictions for this term
  • inheritable — if True, newly inserted residues (in GrandCanonical mode) inherit this energy term from their neighbors. Set to False for terms like TemplateMatchEnergy where new residues would be ill-defined
  • weight — multiplier applied to the unweighted energy
  • residue_groups — list of ResidueGroup tuples, each containing arrays of chain IDs and residue indices

Implementing compute()

The only method you must implement is compute(). It receives an OraclesResultDict mapping oracles to their results and must return a tuple of (unweighted_energy, weighted_energy):
from bagel.oracles import OraclesResultDict

def compute(self, oracles_result: OraclesResultDict) -> tuple[float, float]:
    # Get the result from your oracle
    result = oracles_result[self.oracle]

    # Calculate your energy metric (_calculate_metric is a placeholder for your own logic)
    unweighted = self._calculate_metric(result)

    # Return (unweighted, weighted)
    return unweighted, unweighted * self.weight
Guidelines:
  • Unweighted energy should be normalized to the 0–1 range where possible (0 = best, 1 = worst)
  • Weighted energy is unweighted * self.weight — this is what gets summed into the state energy
  • The method is called once per optimization step, after all oracles have produced their results
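
The normalization and weighting guidelines above can be sketched in isolation. This is a minimal illustration, not part of BAGEL: the helper names normalize_to_unit_range and energy_pair are hypothetical, and the linear rescaling with clamping is only one possible way to map a raw metric onto the 0–1 range.

```python
def normalize_to_unit_range(value: float, best: float, worst: float) -> float:
    """Linearly map a raw metric so that `best` -> 0.0 and `worst` -> 1.0.

    Values outside the [best, worst] interval are clamped to [0, 1].
    (Hypothetical helper, not a BAGEL function.)
    """
    if best == worst:
        raise ValueError("best and worst must differ")
    scaled = (value - best) / (worst - best)
    return min(1.0, max(0.0, scaled))


def energy_pair(raw_metric: float, best: float, worst: float, weight: float) -> tuple[float, float]:
    """Return (unweighted, weighted), the same shape compute() returns."""
    unweighted = normalize_to_unit_range(raw_metric, best, worst)
    return unweighted, unweighted * weight


# Example: a confidence-like score where 1.0 is ideal and 0.0 is the worst case.
print(energy_pair(raw_metric=0.75, best=1.0, worst=0.0, weight=2.0))  # (0.25, 0.5)
```

Keeping the unweighted value on a common scale means the weight alone controls how strongly a term contributes relative to the others.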

Working with oracle results

The OraclesResultDict provides typed access to oracle outputs:
def compute(self, oracles_result: OraclesResultDict) -> tuple[float, float]:
    # For a FoldingOracle — get the predicted structure
    structure = oracles_result.get_structure(self.oracle)  # biotite AtomArray

    # For an EmbeddingOracle — get per-residue embeddings
    embeddings = oracles_result.get_embeddings(self.oracle)  # NDArray

    # For any oracle — get the raw result object
    result = oracles_result[self.oracle]
    # Access oracle-specific attributes (e.g., result.ptm, result.pae)

Residue groups

Residue groups define which residues an energy term operates on. The helper methods get_residue_mask() and get_atom_mask() let you filter structure data to your groups:
def compute(self, oracles_result: OraclesResultDict) -> tuple[float, float]:
    structure = oracles_result.get_structure(self.oracle)

    # Get a boolean mask over residues for your first group
    residue_mask = self.get_residue_mask(structure, residue_group_index=0)

    # Get a boolean mask over atoms for your first group
    atom_mask = self.get_atom_mask(structure, residue_group_index=0)

    # Use masks to filter structure data
    group_coords = structure.coord[atom_mask]
    # ...
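The two masks have different lengths: one entry per residue versus one entry per atom. The toy sketch below illustrates that distinction with plain Python lists standing in for the arrays; the data layout is invented for illustration and does not use BAGEL or biotite.

```python
# Toy structure: one entry per atom, as (residue_index, atom_name).
atoms = [
    (0, "N"), (0, "CA"), (0, "C"),
    (1, "N"), (1, "CA"), (1, "C"),
    (2, "N"), (2, "CA"), (2, "C"),
]
per_residue_score = [0.9, 0.4, 0.7]  # one value per residue
group = {0, 2}                       # residue indices belonging to the group

# Residue-level mask: one boolean per residue.
residue_mask = [i in group for i in range(len(per_residue_score))]

# Atom-level mask: one boolean per atom.
atom_mask = [res_idx in group for res_idx, _ in atoms]

# Per-residue data is filtered with the residue mask,
# per-atom data with the atom mask.
group_scores = [s for s, keep in zip(per_residue_score, residue_mask) if keep]
group_atoms = [a for a, keep in zip(atoms, atom_mask) if keep]

print(group_scores)      # [0.9, 0.7]
print(len(group_atoms))  # 6
```

Applying a residue-level mask to per-atom data (or the reverse) fails or silently selects the wrong entries, so it helps to keep track of which level each array lives on.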
For grand-canonical support (when residues can be inserted or deleted), the base class handles index bookkeeping for residue_groups automatically. You may still need to implement remove_residue() and add_residue() if your term keeps additional internal state that must be updated when residue indices shift.
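
To make the index shift concrete, here is a minimal sketch of the bookkeeping involved. It does not reproduce the signatures of remove_residue() or add_residue(), which this guide does not show; the function names below are hypothetical, and the convention that an insertion at a position pushes the residue already there one slot to the right is an assumption.

```python
def indices_after_insertion(indices: list[int], position: int) -> list[int]:
    """Residues at or after `position` move one slot to the right."""
    return [i + 1 if i >= position else i for i in indices]


def indices_after_deletion(indices: list[int], position: int) -> list[int]:
    """Drop the deleted residue; residues after it move one slot to the left."""
    return [i - 1 if i > position else i for i in indices if i != position]


tracked = [2, 5, 9]  # residue indices some custom state refers to
print(indices_after_insertion(tracked, 5))  # [2, 6, 10]
print(indices_after_deletion(tracked, 5))   # [2, 8]
```

Any per-residue cache or lookup table your term maintains would need an equivalent update so it stays aligned with the sequence.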

Example: a custom energy term

Here is a complete example of a custom energy term that penalizes low per-residue pLDDT confidence on a specific residue group, but only for residues that are also surface-exposed (SASA above a threshold):
import numpy as np
from biotite.structure import sasa
from bagel.energies import EnergyTerm, residue_list_to_group
from bagel.oracles import OraclesResultDict
from bagel.oracles.folding import FoldingOracle


class SurfaceConfidenceEnergy(EnergyTerm):
    """Penalizes low pLDDT on surface-exposed residues in a group."""

    def __init__(self, oracle: FoldingOracle, residues: list, weight: float = 1.0,
                 sasa_threshold: float = 0.2):
        super().__init__(
            name="SurfaceConfidenceEnergy",
            oracle=oracle,
            inheritable=True,
            weight=weight,
        )
        self.residue_groups = [residue_list_to_group(residues)]
        self.sasa_threshold = sasa_threshold

    def compute(self, oracles_result: OraclesResultDict) -> tuple[float, float]:
        structure = oracles_result.get_structure(self.oracle)
        result = oracles_result[self.oracle]

        # Get pLDDT values and residue mask for our group
        residue_mask = self.get_residue_mask(structure, residue_group_index=0)
        atom_mask = self.get_atom_mask(structure, residue_group_index=0)
        plddt = result.local_plddt[0][residue_mask]

        # Compute SASA for the full structure, then filter to our group
        atom_sasa = sasa(structure)
        group_sasa = atom_sasa[atom_mask]

        # Identify surface-exposed residues: use each residue's CA-atom SASA
        # as a per-residue proxy and compare it to the threshold
        # (biotite's sasa() reports per-atom areas in Å²)
        ca_mask = structure.atom_name[atom_mask] == "CA"
        surface_mask = group_sasa[ca_mask] > self.sasa_threshold

        if surface_mask.sum() == 0:
            return 0.0, 0.0

        # Energy = negative mean pLDDT of surface residues (lower pLDDT = higher energy)
        unweighted = -float(np.mean(plddt[surface_mask]))
        return unweighted, unweighted * self.weight
Usage:
import bagel as bg

esmfold = bg.oracles.ESMFold(use_modal=True)

# binder_residues is assumed to be a list of residues defined earlier in your script
energy = SurfaceConfidenceEnergy(
    oracle=esmfold,
    residues=binder_residues,
    weight=2.0,
    sasa_threshold=0.3,
)
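
The selection logic inside compute() can be checked without running a folding oracle. The sketch below mirrors it as a plain function over lists; surface_confidence is a hypothetical stand-in written for this illustration, and the input values are made up.

```python
def surface_confidence(plddt: list[float], ca_sasa: list[float], threshold: float) -> float:
    """Negative mean pLDDT over residues whose CA SASA exceeds the threshold.

    Returns 0.0 when no residue counts as surface-exposed, matching the
    early return in SurfaceConfidenceEnergy.compute().
    """
    surface = [p for p, s in zip(plddt, ca_sasa) if s > threshold]
    if not surface:
        return 0.0
    return -sum(surface) / len(surface)


# Residues 0 and 2 are exposed; residue 1 is buried and ignored.
print(surface_confidence([0.5, 1.0, 0.25], [10.0, 0.0, 20.0], threshold=0.3))  # -0.375

# No exposed residues: the term contributes nothing.
print(surface_confidence([0.5], [0.0], threshold=0.3))  # 0.0
```

Separating the arithmetic from the oracle access like this makes it easy to unit-test a new term before plugging it into an optimization run.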