Definition
The embedding similarity energy measures how much the current per-residue embeddings deviate from reference embeddings, using cosine similarity.
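$$E_{\text{embedding}} = 1 - \frac{1}{N_G} \sum_{\alpha \in G} \cos\left(e_\alpha, e_\alpha^{\text{ref}}\right)$$
- $e_\alpha$ is the current embedding vector for residue $\alpha$ (L2-normalized)
- $e_\alpha^{\text{ref}}$ is the reference embedding vector for residue $\alpha$ (L2-normalized)
- $\cos(\cdot,\cdot)$ is the cosine similarity
- $N_G$ is the number of residues in the group $G$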
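The definition above can be sketched in plain NumPy. This is a minimal illustration and not the library's implementation: the function name `embedding_similarity_energy` and the toy arrays are assumptions made for this example, and each array is assumed to hold one embedding per row for the residues in the group.

```python
import numpy as np
import numpy.typing as npt


def embedding_similarity_energy(
    current: npt.NDArray[np.float64],
    reference: npt.NDArray[np.float64],
) -> float:
    """Return 1 minus the mean per-residue cosine similarity (hypothetical helper)."""
    if current.shape != reference.shape:
        raise ValueError("current and reference must have the same shape")
    # L2-normalize each row so the row-wise dot product equals the cosine similarity.
    current_unit = current / np.linalg.norm(current, axis=1, keepdims=True)
    reference_unit = reference / np.linalg.norm(reference, axis=1, keepdims=True)
    cosines = np.sum(current_unit * reference_unit, axis=1)
    return float(1.0 - cosines.mean())


# Toy data: two residues with two-dimensional embeddings.
reference = np.array([[1.0, 0.0], [0.0, 2.0]])
identical = embedding_similarity_energy(reference, reference)
orthogonal = embedding_similarity_energy(np.array([[0.0, 1.0], [3.0, 0.0]]), reference)
opposite = embedding_similarity_energy(-reference, reference)
```

Because cosine similarity lies in [-1, 1], the energy lies in [0, 2]: it is 0 when every embedding points in the same direction as its reference (`identical`), 1 when all pairs are orthogonal (`orthogonal`), and 2 when all pairs point in opposite directions (`opposite`).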
Parameters
oracle
The oracle that will be used to calculate the embeddings.
residues
Which residues to include in the calculation.
reference_embeddings
npt.NDArray[np.float64]
required
The reference embeddings to compare to.
weight
The weight of the energy term.
Optional name to append to the energy term name.
Methods
compute
Parameters
oracles_result
OraclesResultDict
required
conserved_index_list
Returns the indices of the conserved residues (stored in .residue_group[0]) in the pLM embedding array.
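To illustrate what "indices in the embedding array" means, here is a minimal sketch. It assumes the embedding array has one row per residue in chain order and that a residue can be identified by a `(chain_ID, index)` pair. The `ResidueKey` alias and the `conserved_rows` helper are hypothetical and are not part of the library.

```python
from typing import List, Tuple

ResidueKey = Tuple[str, int]  # hypothetical (chain_ID, index) identifier


def conserved_rows(
    all_residues: List[ResidueKey],
    conserved: List[ResidueKey],
) -> List[int]:
    """Map each conserved residue to its row in a row-per-residue embedding array."""
    # Row position of every residue, in the order the embeddings are stacked.
    row_of = {key: row for row, key in enumerate(all_residues)}
    return [row_of[key] for key in conserved]


# Two chains stacked one after the other: A has 3 residues, B has 2.
all_residues = [("A", 0), ("A", 1), ("A", 2), ("B", 0), ("B", 1)]
rows = conserved_rows(all_residues, [("A", 2), ("B", 0)])
```

Under these assumptions, `rows` can be used to select the conserved rows of an embedding array, for example `embeddings[rows]`.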
Parameters
Example
import numpy as np
import bagel as bg
# Define residues with some conserved (immutable) positions.
# `sequence` is assumed to be an amino-acid string defined beforehand, longer than 41 residues.
residues = [
    bg.Residue(name=aa, chain_ID="A", index=i, mutable=(i < 20 or i > 40))
    for i, aa in enumerate(sequence)
]
conserved_residues = [r for r in residues if not r.mutable]
chain = bg.Chain(residues=residues)
# Create the embedding oracle and extract reference embeddings
esm2 = bg.oracles.ESM2(config={"model_name": "esm2_t33_650M_UR50D"})
result = esm2.embed(chains=[chain])
immutable_mask = ~np.array([r.mutable for r in residues])
reference_embeddings = result.embeddings[immutable_mask]
# Maintain embedding similarity at conserved positions
embedding_energy = bg.energies.EmbeddingsSimilarityEnergy(
    oracle=esm2,
    residues=conserved_residues,
    reference_embeddings=reference_embeddings,
    weight=1.0,
)
# Add to a state
state = bg.State(
    chains=[chain],
    energy_terms=[embedding_energy],
    name="my_state",
)