import numpy as np
from biotite.structure import apply_residue_wise, get_residue_starts
from biotite.structure import sasa

from bagel.energies import EnergyTerm, residue_list_to_group
from bagel.oracles import OraclesResultDict
from bagel.oracles.folding import FoldingOracle

class SurfaceConfidenceEnergy(EnergyTerm):
    """Penalize low pLDDT on the surface-exposed residues of a residue group.

    A residue of the group counts as surface-exposed when its *relative* SASA
    exceeds ``sasa_threshold``. Relative SASA is the sum of its atoms' SASA (in
    Å², computed on the full predicted structure) divided by the theoretical
    maximum SASA of its residue type (Tien et al., 2013). The energy is the
    negative mean pLDDT over those residues, so lower confidence on the surface
    yields a higher energy.

    Parameters
    ----------
    oracle : FoldingOracle
        Oracle whose predicted structure and per-residue pLDDT are scored.
    residues : list
        Residues forming the scored group; converted with
        ``residue_list_to_group``.
    weight : float
        Factor applied to the unweighted energy to obtain the weighted energy.
    sasa_threshold : float
        Relative-SASA cut-off (a dimensionless fraction; ~0.2 is a common
        surface/core boundary). Residues strictly above it are "surface".
    """

    # Theoretical maximum SASA per residue type in Å² (Tien et al., 2013).
    # Residue names missing from this table get NaN relative SASA and are
    # therefore never classified as surface-exposed.
    _MAX_ASA = {
        "ALA": 129.0, "ARG": 274.0, "ASN": 195.0, "ASP": 193.0,
        "CYS": 167.0, "GLN": 225.0, "GLU": 223.0, "GLY": 104.0,
        "HIS": 224.0, "ILE": 197.0, "LEU": 201.0, "LYS": 236.0,
        "MET": 224.0, "PHE": 240.0, "PRO": 159.0, "SER": 155.0,
        "THR": 172.0, "TRP": 285.0, "TYR": 263.0, "VAL": 174.0,
    }

    def __init__(self, oracle: FoldingOracle, residues: list, weight: float = 1.0,
                 sasa_threshold: float = 0.2):
        super().__init__(
            name="SurfaceConfidenceEnergy",
            oracle=oracle,
            inheritable=True,
            weight=weight,
        )
        self.residue_groups = [residue_list_to_group(residues)]
        self.sasa_threshold = sasa_threshold

    def compute(self, oracles_result: OraclesResultDict) -> tuple[float, float]:
        """Return ``(unweighted, weighted)`` energy for the residue group.

        Returns ``(0.0, 0.0)`` when the group is empty or none of its residues
        exceeds the relative-SASA threshold.

        Raises
        ------
        ValueError
            If the number of residues found among the group's atoms differs
            from the number of pLDDT values selected by the residue mask.
        """
        structure = oracles_result.get_structure(self.oracle)
        result = oracles_result[self.oracle]

        residue_mask = self.get_residue_mask(structure, residue_group_index=0)
        atom_mask = self.get_atom_mask(structure, residue_group_index=0)
        # NOTE(review): assumes local_plddt carries a leading batch axis and
        # index 0 is this structure — confirm against the oracle result type.
        plddt = result.local_plddt[0][residue_mask]

        group_atoms = structure[atom_mask]
        if len(group_atoms) == 0:
            return 0.0, 0.0

        # SASA is computed on the full structure so atoms outside the group
        # still occlude it; only afterwards is it restricted to the group.
        atom_sasa = sasa(structure)
        # Sum per residue; nansum skips atoms biotite leaves as NaN (atoms
        # excluded from the calculation, e.g. ions).
        residue_sasa = apply_residue_wise(
            group_atoms, atom_sasa[atom_mask], np.nansum
        )
        residue_names = group_atoms.res_name[get_residue_starts(group_atoms)]
        max_asa = np.array(
            [self._MAX_ASA.get(str(name), np.nan) for name in residue_names],
            dtype=float,
        )
        relative_sasa = np.asarray(residue_sasa, dtype=float) / max_asa

        if len(relative_sasa) != len(plddt):
            raise ValueError(
                f"Residue count mismatch: {len(relative_sasa)} residues in the "
                f"group's atoms but {len(plddt)} pLDDT values."
            )

        # NaN (unknown residue type) compares False, so it is never surface.
        surface_mask = relative_sasa > self.sasa_threshold
        if not surface_mask.any():
            # NOTE(review): for non-negative pLDDT this is the highest value the
            # formula below can produce, i.e. burying every residue is
            # penalised as much as zero confidence — confirm this is intended.
            return 0.0, 0.0

        # Lower surface pLDDT -> less negative (higher) energy.
        unweighted = -float(np.mean(plddt[surface_mask]))
        return unweighted, unweighted * self.weight