import numpy as np
from dataclasses import dataclass, field
from typing import Dict
from bagel.mutation import MutationProtocol, Mutation, MutationRecord
from bagel.system import System
from bagel.constants import mutation_bias_no_cystein
# One-letter codes of the amino acids this module treats as charged;
# PositionBiasedProtocol upweights them in the first 30% of a chain.
CHARGED_AAS = set("DEKR")
# One-letter codes of the amino acids this module treats as hydrophobic;
# PositionBiasedProtocol upweights them in the last 30% of a chain.
HYDROPHOBIC_AAS = set("AILMFVW")
@dataclass
class PositionBiasedProtocol(MutationProtocol):
    """Substitution protocol whose amino-acid bias depends on residue position.

    For each mutation the residue's relative position in its chain
    (``index / (chain.length - 1)``) selects a set of favored amino acids:

    * relative position < 0.3  -> ``CHARGED_AAS``
    * relative position > 0.7  -> ``HYDROPHOBIC_AAS``
    * otherwise                -> no favored set

    The base weights from ``self.mutation_bias`` are multiplied by
    ``bias_strength`` for every favored amino acid and then renormalised.

    ``mutation_bias``, ``n_mutations``, ``exclude_self`` and ``choose_chain``
    are not defined here; they are expected to be provided by
    ``MutationProtocol``.
    """

    bias_strength: float = 3.0  # multiplicative weight applied to favored AAs

    def one_step(self, system: System) -> tuple[System, MutationRecord]:
        """Apply ``self.n_mutations`` position-biased substitutions to a copy.

        Args:
            system: The system to mutate. It is copied via ``__copy__`` and
                the mutations are applied to the copy.

        Returns:
            A tuple ``(mutated_system, record)`` where ``record`` lists one
            ``Mutation`` per substitution, in the order they were applied.

        Raises:
            ValueError: If, for a chosen residue, no candidate amino acid has
                a positive weight left (for example when the only non-zero
                candidate is the current residue and ``exclude_self`` is set).
        """
        mutated_system = system.__copy__()
        mutations: list[Mutation] = []

        # The candidate list and base weights do not change between
        # mutations, so build them once instead of once per iteration.
        aa_keys = list(self.mutation_bias)
        base_probs = np.array(
            [self.mutation_bias[aa] for aa in aa_keys], dtype=float
        )
        aa_position = {aa: i for i, aa in enumerate(aa_keys)}

        for _ in range(self.n_mutations):
            chain = self.choose_chain(mutated_system)
            # Convert the NumPy scalar to a built-in int so the recorded
            # Mutation does not carry NumPy types.
            index = int(np.random.choice(chain.mutable_residue_indexes))

            # Relative position in [0, 1]; max(..., 1) guards single-residue
            # chains against division by zero.
            # NOTE(review): assumes residue indexes run 0..length-1 - confirm.
            relative_pos = index / max(chain.length - 1, 1)
            favored: set[str]
            if relative_pos < 0.3:
                favored = CHARGED_AAS
            elif relative_pos > 0.7:
                favored = HYDROPHOBIC_AAS
            else:
                favored = set()

            # Work on a copy so that the per-residue bias never accumulates
            # in the shared base weights.
            probs = base_probs.copy()
            favored_mask = np.array([aa in favored for aa in aa_keys], dtype=bool)
            probs[favored_mask] *= self.bias_strength

            current_aa = chain.residues[index].name
            if self.exclude_self:
                # A residue that is not among the candidates can never be
                # drawn, so there is nothing to exclude in that case.
                current_position = aa_position.get(current_aa)
                if current_position is not None:
                    probs[current_position] = 0.0

            total = probs.sum()
            # ``not total > 0`` also catches NaN totals.
            if not total > 0:
                raise ValueError(
                    "No amino acid with non-zero probability available to "
                    f"replace {current_aa!r} at residue {index} of chain "
                    f"{chain.chain_ID!r}"
                )
            probs /= total

            # np.random.choice returns np.str_; store a plain str.
            new_aa = str(np.random.choice(aa_keys, p=probs))
            chain.mutate_residue(index=index, amino_acid=new_aa)
            mutations.append(
                Mutation(
                    chain_id=chain.chain_ID,
                    move_type="substitution",
                    residue_index=index,
                    old_amino_acid=current_aa,
                    new_amino_acid=new_aa,
                )
            )

        return mutated_system, MutationRecord(mutations=mutations)