Definition

The embedding similarity energy measures how much the current per-residue embeddings deviate from reference embeddings, using cosine similarity:

$$E_{\mathrm{embedding}} = 1 - \frac{1}{N_G} \sum_{\alpha \in G} \cos(\mathbf{e}_\alpha, \mathbf{e}_\alpha^{\mathrm{ref}})$$

  • $\mathbf{e}_\alpha$ is the current embedding vector for residue $\alpha$ (L2-normalized)
  • $\mathbf{e}_\alpha^{\mathrm{ref}}$ is the reference embedding vector for residue $\alpha$ (L2-normalized)
  • $\cos(\cdot, \cdot)$ is the cosine similarity
  • $N_G$ is the number of residues in the group $G$
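
The formula above can be sketched in plain NumPy to show what the energy evaluates to. This is an illustrative sketch, not the library's implementation: the function name `embedding_similarity_energy` is hypothetical, and it assumes both inputs are arrays of shape `(N_G, D)` with one row per residue in the group.

```python
import numpy as np


def embedding_similarity_energy(current, reference):
    """Return 1 minus the mean per-residue cosine similarity.

    Hypothetical helper for illustration; both arguments are assumed to be
    (N_G, D) arrays with one embedding vector per residue.
    """
    current = np.asarray(current, dtype=np.float64)
    reference = np.asarray(reference, dtype=np.float64)
    if current.shape != reference.shape:
        raise ValueError("current and reference must have the same shape")

    # L2-normalize each row so the dot product equals the cosine similarity
    cur = current / np.linalg.norm(current, axis=1, keepdims=True)
    ref = reference / np.linalg.norm(reference, axis=1, keepdims=True)

    # Per-residue cosine similarity, then average over the group
    cos = np.sum(cur * ref, axis=1)
    return float(1.0 - cos.mean())


reference = np.array([[1.0, 0.0], [0.0, 2.0]])
orthogonal = np.array([[0.0, 1.0], [3.0, 0.0]])

same_energy = embedding_similarity_energy(reference, reference)
orth_energy = embedding_similarity_energy(orthogonal, reference)
```

With identical embeddings every cosine is 1 and the energy is 0; with orthogonal embeddings every cosine is 0 and the energy is 1. Since cosine similarity lies in [-1, 1], the energy is bounded between 0 and 2.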

Parameters

oracle
EmbeddingOracle
required
The oracle that will be used to calculate the embeddings.
residues
list[Residue]
required
Which residues to include in the calculation.
reference_embeddings
npt.NDArray[np.float64]
required
The reference embeddings to compare to.
weight
float
default:"1.0"
The weight of the energy term.
name
str | None
default:"None"
Optional name to append to the energy term name.

Methods

compute

Parameters
oracles_result
OraclesResultDict
required

conserved_index_list

Returns the indices of the conserved residues (stored in .residue_group[0]) in the pLM embedding array.

Parameters
chains
list[Chain]
required
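
As a rough illustration of what such an index lookup involves, the sketch below maps conserved residues to row positions in an embedding array. It rests on assumptions not confirmed by this page: that the embedding array concatenates the chains in the order given, with one row per residue, and that a residue is identified by its chain ID and index. `Res` and `conserved_rows` are hypothetical stand-ins, not part of the library.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Res:
    """Hypothetical stand-in for a residue: chain ID plus index in the chain."""
    chain_ID: str
    index: int


def conserved_rows(chains, conserved):
    """Return row indices of conserved residues in a concatenated array.

    Assumes rows follow the order of `chains`, one row per residue.
    """
    wanted = {(r.chain_ID, r.index) for r in conserved}
    rows = []
    offset = 0  # number of rows contributed by earlier chains
    for chain in chains:
        for pos, res in enumerate(chain):
            if (res.chain_ID, res.index) in wanted:
                rows.append(offset + pos)
        offset += len(chain)
    return rows


chain_a = [Res("A", i) for i in range(5)]
chain_b = [Res("B", i) for i in range(4)]
conserved = [chain_a[1], chain_a[3], chain_b[0]]

rows = conserved_rows([chain_a, chain_b], conserved)
```

Here the first residue of chain B lands at row 5 because the five residues of chain A come first in the concatenated array.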

Example

import numpy as np
import bagel as bg

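# Placeholder amino-acid sequence (hypothetical; substitute your own protein)
sequence = "ACDEFGHIKLMNPQRSTVWY" * 3

# Define residues with some conserved (immutable) positions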
residues = [
    bg.Residue(name=aa, chain_ID="A", index=i, mutable=(i < 20 or i > 40))
    for i, aa in enumerate(sequence)
]
conserved_residues = [r for r in residues if not r.mutable]
chain = bg.Chain(residues=residues)

# Create the embedding oracle and extract reference embeddings
esm2 = bg.oracles.ESM2(config={"model_name": "esm2_t33_650M_UR50D"})
result = esm2.embed(chains=[chain])
immutable_mask = ~np.array([r.mutable for r in residues])
reference_embeddings = result.embeddings[immutable_mask]

# Maintain embedding similarity at conserved positions
embedding_energy = bg.energies.EmbeddingsSimilarityEnergy(
    oracle=esm2,
    residues=conserved_residues,
    reference_embeddings=reference_embeddings,
    weight=1.0,
)

# Add to a state
state = bg.State(
    chains=[chain],
    energy_terms=[embedding_energy],
    name="my_state",
)