Model
Main Model
Heads
- class alphagenome_pytorch.heads.ContactMapsHead(*args, **kwargs)
Bases: Module
Predicts contact maps from pairwise embeddings.
JAX: alphagenome_research.model.heads.ContactMapsHead
- class alphagenome_pytorch.heads.GenomeTracksHead(*args, **kwargs)
Bases: Module
Predicts genome tracks at multiple resolutions.
Internal computation uses NCL layout for Conv1d efficiency. By default, outputs are in NLC format (B, S, T) to match the JAX reference (B = batch, S = sequence positions at the given resolution, T = tracks).
Matches JAX: alphagenome_research.model.heads.GenomeTracksHead
- Parameters:
init_scheme (Literal['truncated_normal', 'uniform'])
- forward(embeddings_dict, organism_index, return_scaled=False, channels_last=True)
Returns predictions in either experimental or model scale, depending on return_scaled.
- Parameters:
embeddings_dict – Dict mapping each resolution to its embeddings, in NCL format (B, C, S)
organism_index – Organism indices (B,)
return_scaled – If True, return predictions in model space (for loss computation). If False, return predictions in experimental space (for inference).
channels_last – Output format:
- True (default): NLC format (B, S, T); user-friendly and matches JAX.
- False: NCL format (B, T, S); more efficient for training, since no transposes are needed.
- Returns:
Dict mapping each resolution to its predictions, in the format selected by channels_last
- scale(x, organism_index, resolution, channels_last=True)
Scales targets from experimental to model prediction space.
- Parameters:
x – Targets in experimental space. NLC (B, S, T) if channels_last else NCL (B, T, S)
organism_index – Organism indices (B,)
resolution – Resolution (1 or 128)
channels_last – If True, x is NLC. If False, x is NCL.
- Returns:
Targets in model space (same format as input)
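The layout contract of forward() can be illustrated with a toy stand-in. The sketch below is hypothetical: ToyTracksHead is not part of alphagenome_pytorch, it ignores organism_index and scaling, and it assumes one 1x1 Conv1d per resolution. It only shows how an NCL embeddings_dict maps to NCL (B, T, S) or NLC (B, S, T) outputs through channels_last.

```python
import torch
import torch.nn as nn


class ToyTracksHead(nn.Module):
    """Hypothetical stand-in for GenomeTracksHead, used only to illustrate layouts.

    One 1x1 Conv1d per resolution maps C channels to T tracks. The real head
    also uses organism_index and applies scaling; both are omitted here.
    """

    def __init__(self, in_channels: int, num_tracks: int, resolutions=(1, 128)):
        super().__init__()
        # nn.ModuleDict needs string keys, so resolutions are stored as "1", "128".
        self.proj = nn.ModuleDict(
            {str(r): nn.Conv1d(in_channels, num_tracks, kernel_size=1) for r in resolutions}
        )

    def forward(self, embeddings_dict, channels_last=True):
        out = {}
        for res, emb in embeddings_dict.items():
            y = self.proj[str(res)](emb)  # NCL in, NCL out: (B, T, S)
            # Transpose only when the caller asks for NLC (B, S, T).
            out[res] = y.transpose(1, 2) if channels_last else y
        return out


B, C, T, L = 2, 16, 4, 1024
embeddings = {1: torch.randn(B, C, L), 128: torch.randn(B, C, L // 128)}
head = ToyTracksHead(C, T)
nlc = head(embeddings, channels_last=True)
ncl = head(embeddings, channels_last=False)
print(nlc[1].shape)    # torch.Size([2, 1024, 4])
print(nlc[128].shape)  # torch.Size([2, 8, 4])
```

Keeping the head in NCL internally and transposing only at the boundary is what lets channels_last=False skip the transpose entirely during training.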
- class alphagenome_pytorch.heads.MultiOrganismConv1d(*args, **kwargs)
Bases: Module
Organism-specific 1x1 convolution for NCL format (B, C, S).
Equivalent to the JAX _MultiOrganismLinear, which operates on NLC format. Using Conv1d avoids transpose overhead when the data is already NCL.
- Parameters:
init_scheme (Literal['truncated_normal', 'uniform'])
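An organism-specific 1x1 convolution amounts to picking one weight matrix per batch element and applying it at every position. The sketch below is a hypothetical illustration (MultiOrganismConv1dSketch, its constructor arguments, and its initialization are assumptions, not the library's API); it checks the batched matmul against a per-sample F.conv1d with kernel size 1.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MultiOrganismConv1dSketch(nn.Module):
    """Hypothetical sketch: 1x1 conv whose weights are selected per organism.

    Input x is NCL (B, C_in, S); organism_index is an integer tensor (B,).
    """

    def __init__(self, in_channels: int, out_channels: int, num_organisms: int):
        super().__init__()
        # One (C_out, C_in) weight matrix and one bias vector per organism.
        self.weight = nn.Parameter(
            torch.randn(num_organisms, out_channels, in_channels) * in_channels**-0.5
        )
        self.bias = nn.Parameter(torch.zeros(num_organisms, out_channels))

    def forward(self, x: torch.Tensor, organism_index: torch.Tensor) -> torch.Tensor:
        w = self.weight[organism_index]  # (B, C_out, C_in)
        b = self.bias[organism_index]    # (B, C_out)
        # A 1x1 conv is a matmul over channels at each position:
        # (B, C_out, C_in) @ (B, C_in, S) -> (B, C_out, S).
        return torch.bmm(w, x) + b.unsqueeze(-1)


torch.manual_seed(0)
layer = MultiOrganismConv1dSketch(8, 3, num_organisms=2)
x = torch.randn(4, 8, 32)
org = torch.tensor([0, 1, 1, 0])
y = layer(x, org)

# Reference: per-sample F.conv1d using that sample's organism weights (kernel size 1).
ref = torch.stack([
    F.conv1d(x[i:i + 1], layer.weight[o].unsqueeze(-1), layer.bias[o])[0]
    for i, o in enumerate(org.tolist())
])
print(torch.allclose(y, ref, atol=1e-5))  # True
```

Because the input is already NCL, no transpose is needed; the NLC form would compute the same numbers as x.transpose(1, 2) @ w.transpose(1, 2).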
- class alphagenome_pytorch.heads.MultiOrganismLinear(*args, **kwargs)
Bases: Module
Linear layer with organism-specific weights. Expects NLC format (B, S, C).
Used for non-sequence operations, such as ContactMapsHead on pair activations.
JAX: alphagenome_research.model.heads._MultiOrganismLinear
- Parameters:
init_scheme (Literal['truncated_normal', 'uniform'])
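A channels-last linear layer with per-organism weights can be written with torch.bmm after flattening any middle axes. This is a hypothetical sketch (MultiOrganismLinearSketch is not the library class), and it assumes pair activations are laid out as (B, S, S, C), which is an illustration rather than a documented shape.

```python
import torch
import torch.nn as nn


class MultiOrganismLinearSketch(nn.Module):
    """Hypothetical sketch of a linear layer with per-organism weights.

    x has layout (B, ..., C_in), channels last. Extra middle axes (e.g. the two
    position axes of pair activations) are flattened before the matmul.
    """

    def __init__(self, in_features: int, out_features: int, num_organisms: int):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(num_organisms, in_features, out_features) * in_features**-0.5
        )
        self.bias = nn.Parameter(torch.zeros(num_organisms, out_features))

    def forward(self, x: torch.Tensor, organism_index: torch.Tensor) -> torch.Tensor:
        w = self.weight[organism_index]  # (B, C_in, C_out)
        b = self.bias[organism_index]    # (B, C_out)
        flat = x.reshape(x.shape[0], -1, x.shape[-1])  # (B, N, C_in)
        out = torch.bmm(flat, w) + b.unsqueeze(1)      # (B, N, C_out)
        return out.reshape(*x.shape[:-1], w.shape[-1])  # restore middle axes


torch.manual_seed(0)
layer = MultiOrganismLinearSketch(16, 5, num_organisms=2)
org = torch.tensor([1, 0])
seq = torch.randn(2, 64, 16)     # NLC (B, S, C)
pair = torch.randn(2, 8, 8, 16)  # assumed pair layout (B, S, S, C)
print(layer(seq, org).shape)   # torch.Size([2, 64, 5])
print(layer(pair, org).shape)  # torch.Size([2, 8, 8, 5])
```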
- class alphagenome_pytorch.heads.SpliceSitesClassificationHead(*args, **kwargs)
Bases: Module
Predicts splice site classes at each position.
Internal computation uses NCL layout for Conv1d efficiency. Outputs are in NLC format (B, S, 5) to match the JAX reference.
Classes: Donor+, Acceptor+, Donor-, Acceptor-, Not a splice site.
JAX: alphagenome_research.model.heads.SpliceSitesClassificationHead
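The (B, S, 5) output can be decoded into per-position calls using the class order listed above. The helper below is hypothetical (decode_splice_sites and min_prob are not library names), and it assumes the head returns unnormalized logits over the last axis; if it already returns probabilities, the softmax should be dropped.

```python
import torch

# Class order as listed for SpliceSitesClassificationHead.
SPLICE_CLASSES = ["Donor+", "Acceptor+", "Donor-", "Acceptor-", "Not a splice site"]


def decode_splice_sites(logits: torch.Tensor, min_prob: float = 0.5):
    """Hypothetical helper: turn NLC (B, S, 5) logits into per-sequence site calls.

    Returns, for each batch element, a list of (position, class_name, probability)
    for positions whose top class is a splice site with probability >= min_prob.
    """
    probs = torch.softmax(logits, dim=-1)  # assumes logits; (B, S, 5)
    conf, idx = probs.max(dim=-1)          # both (B, S)
    calls = []
    for b in range(logits.shape[0]):
        sites = [
            (pos, SPLICE_CLASSES[k], p)
            for pos, (k, p) in enumerate(zip(idx[b].tolist(), conf[b].tolist()))
            if k != 4 and p >= min_prob  # index 4 is "Not a splice site"
        ]
        calls.append(sites)
    return calls


logits = torch.full((1, 6, 5), -5.0)
logits[0, :, 4] = 5.0   # background class everywhere
logits[0, 2, 0] = 10.0  # strong Donor+ at position 2
logits[0, 4, 3] = 10.0  # strong Acceptor- at position 4
print([(pos, name) for pos, name, _ in decode_splice_sites(logits)[0]])
# [(2, 'Donor+'), (4, 'Acceptor-')]
```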
- class alphagenome_pytorch.heads.SpliceSitesJunctionHead(*args, **kwargs)
Bases: Module
Predicts splice junction read counts. Expects NCL format (B, C, S).
JAX: alphagenome_research.model.heads.SpliceSitesJunctionHead
- class alphagenome_pytorch.heads.SpliceSitesUsageHead(*args, **kwargs)
Bases: Module
Predicts splice site usage, i.e. the proportion of RNA that uses each splice site.
Internal computation uses NCL layout for Conv1d efficiency. Outputs are in NLC format (B, S, T) to match the JAX reference.
JAX: alphagenome_research.model.heads.SpliceSitesUsageHead
- alphagenome_pytorch.heads.predictions_scaling(x, track_means, resolution, apply_squashing, soft_clip_value=10.0, channels_last=True)
Scales model predictions to the experimental data scale.
Matches JAX: alphagenome_research.model.heads.predictions_scaling
- Parameters:
x (torch.Tensor) – Model predictions - NLC (B, S, C) if channels_last else NCL (B, C, S)
track_means (torch.Tensor) – Mean values per track (B, C)
resolution (int) – Bin resolution (1 or 128)
apply_squashing (bool) – Whether to apply power law expansion (for RNA-seq)
soft_clip_value (float) – Value for soft clipping
channels_last (bool) – If True, x is NLC. If False, x is NCL.
- Returns:
Scaled predictions in experimental data space (same format as input)
- Return type:
torch.Tensor
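A minimal sketch of the model-to-experimental direction follows. The exact formulas are assumptions, not taken from the library or the JAX source: the soft clip being undone is assumed to be y -> 2*sqrt(y*c) - c above c, and squashing is assumed to be y -> y**0.75. The function name and the squash_exponent parameter are hypothetical. The part grounded in the signature above is the broadcasting of track_means (B, C) over the sequence axis for both layouts.

```python
import torch


def predictions_scaling_sketch(x, track_means, resolution, apply_squashing,
                               soft_clip_value=10.0, channels_last=True,
                               squash_exponent=0.75):
    """Hypothetical sketch of model-space -> experimental-space scaling.

    ASSUMPTIONS: the soft clip being undone is y -> 2*sqrt(y*c) - c for y > c,
    and squashing is y -> y**squash_exponent. Assumes x >= 0.
    """
    c = soft_clip_value
    # 1. Undo the soft clip: (x + c)^2 / (4c) inverts 2*sqrt(y*c) - c above c.
    x = torch.where(x > c, (x + c) ** 2 / (4 * c), x)
    # 2. Undo the power-law compression (power law expansion, e.g. for RNA-seq).
    if apply_squashing:
        x = x ** (1.0 / squash_exponent)
    # 3. track_means is (B, C): insert the sequence axis where the layout needs it.
    means = track_means.unsqueeze(1) if channels_last else track_means.unsqueeze(-1)
    return x * means * resolution


x = torch.tensor([[[0.5, 2.0], [12.0, 1.0]]])  # NLC (B=1, S=2, C=2)
means = torch.tensor([[2.0, 4.0]])            # (B, C)
y = predictions_scaling_sketch(x, means, resolution=128, apply_squashing=False)
y_ncl = predictions_scaling_sketch(x.transpose(1, 2), means, resolution=128,
                                   apply_squashing=False, channels_last=False)
```

Only the value 12.0 exceeds the clip threshold, so it alone is expanded (to 12.1) before the per-track and resolution multiplication.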
- alphagenome_pytorch.heads.targets_scaling(x, track_means, resolution, apply_squashing, soft_clip_value=10.0, channels_last=True)
Scales targets from the experimental data scale to the model prediction space.
Inverse of predictions_scaling; used to scale targets before loss computation.
Matches JAX: alphagenome_research.model.heads.targets_scaling
- Parameters:
x (torch.Tensor) – Targets in experimental space - NLC (B, S, C) if channels_last else NCL (B, C, S)
track_means (torch.Tensor) – Per-track scaling factors (B, C)
resolution (int) – Resolution multiplier (1 or 128)
apply_squashing (bool) – Apply power law compression (only for RNA-seq)
soft_clip_value (float) – Value for soft clipping
channels_last (bool) – If True, x is NLC. If False, x is NCL.
- Returns:
Targets in model space (same format as input)
- Return type:
torch.Tensor
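The experimental-to-model direction can be sketched the same way, under the same loudly stated assumptions: squashing is assumed to be y -> y**0.75 and the soft clip is assumed to be y -> 2*sqrt(y*c) - c above c, which is continuous at c and grows like a square root beyond it. targets_scaling_sketch and squash_exponent are hypothetical names. The usage inverts the result by hand to show the "inverse of predictions_scaling" relationship under these assumptions.

```python
import torch


def targets_scaling_sketch(x, track_means, resolution, apply_squashing,
                           soft_clip_value=10.0, channels_last=True,
                           squash_exponent=0.75):
    """Hypothetical sketch of experimental-space -> model-space scaling.

    ASSUMPTIONS: squashing is y -> y**squash_exponent and the soft clip is
    y -> 2*sqrt(y*c) - c for y > c. Assumes non-negative targets.
    """
    # track_means is (B, C): insert the sequence axis where the layout needs it.
    means = track_means.unsqueeze(1) if channels_last else track_means.unsqueeze(-1)
    # 1. Normalize by the per-track mean and the bin resolution.
    y = x / (means * resolution)
    # 2. Optional power-law compression (RNA-seq only).
    if apply_squashing:
        y = y ** squash_exponent
    # 3. Soft clip: identity up to c, square-root growth above it.
    c = soft_clip_value
    return torch.where(y > c, 2 * torch.sqrt(y * c) - c, y)


targets = torch.tensor([[[0.0, 50.0], [3000.0, 10.0]]])  # NLC (B=1, S=2, C=2)
means = torch.tensor([[1.0, 2.0]])                       # (B, C)
z = targets_scaling_sketch(targets, means, resolution=1, apply_squashing=True)

# Undo each step in reverse order to recover the original targets.
c = 10.0
recovered = torch.where(z > c, (z + c) ** 2 / (4 * c), z) ** (1 / 0.75) * means.unsqueeze(1)
```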
Embeddings
- class alphagenome_pytorch.embeddings.OutputEmbedder(*args, **kwargs)
Bases: Module
Output embedder using Conv1d for NCL format (B, C, S).
Matches JAX alphagenome_research.model.embeddings.OutputEmbedder.
Logic:
1. Conv1d projection to output channels.
2. Optional skip connection addition (with projection if needed).
3. Add organism embedding.
4. Norm + GELU.
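The four steps can be sketched as a small NCL module. This is a hypothetical illustration, not the library's class: OutputEmbedderSketch and its arguments are invented, and it assumes the skip projection is a 1x1 Conv1d, the skip tensor has the same sequence length as x, the organism embedding is an nn.Embedding broadcast over positions, and "Norm" is a LayerNorm over channels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OutputEmbedderSketch(nn.Module):
    """Hypothetical sketch of the four OutputEmbedder steps on NCL (B, C, S) input.

    ASSUMPTIONS: 1x1 Conv1d skip projection, skip has the same length S as x,
    nn.Embedding for organisms, and LayerNorm over channels as the norm.
    """

    def __init__(self, in_channels, out_channels, num_organisms, skip_channels=None):
        super().__init__()
        self.proj = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        # Project the skip only when its width differs from out_channels.
        self.skip_proj = None
        if skip_channels is not None and skip_channels != out_channels:
            self.skip_proj = nn.Conv1d(skip_channels, out_channels, kernel_size=1)
        self.organism_embed = nn.Embedding(num_organisms, out_channels)
        self.norm = nn.LayerNorm(out_channels)

    def forward(self, x, organism_index, skip=None):
        h = self.proj(x)  # 1. project to out_channels: (B, C_out, S)
        if skip is not None:  # 2. optional skip, projected if widths differ
            h = h + (self.skip_proj(skip) if self.skip_proj is not None else skip)
        # 3. organism embedding (B, C_out) -> (B, C_out, 1), broadcast over S
        h = h + self.organism_embed(organism_index).unsqueeze(-1)
        # 4. LayerNorm expects channels last, so transpose around it, then GELU.
        h = self.norm(h.transpose(1, 2)).transpose(1, 2)
        return F.gelu(h)


emb = OutputEmbedderSketch(in_channels=32, out_channels=64, num_organisms=2, skip_channels=16)
x = torch.randn(3, 32, 100)
skip = torch.randn(3, 16, 100)
out = emb(x, torch.tensor([0, 1, 0]), skip=skip)
print(out.shape)  # torch.Size([3, 64, 100])
```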