Model

Main Model

Heads

class alphagenome_pytorch.heads.ContactMapsHead(*args, **kwargs)

Bases: Module

Predicts contact maps from pairwise embeddings.

JAX: alphagenome_research.model.heads.ContactMapsHead

forward(pair_embeddings, organism_index, channels_last=True)

class alphagenome_pytorch.heads.GenomeTracksHead(*args, **kwargs)

Bases: Module

Predicts genome tracks at multiple resolutions.

Internal computation uses NCL layout (B, C, S) for Conv1d efficiency. By default, outputs are in NLC format (B, S, T) to match the JAX reference.

Matches JAX: alphagenome_research.model.heads.GenomeTracksHead

Parameters:

init_scheme (Literal['truncated_normal', 'uniform']) – Weight initialization scheme.

forward(embeddings_dict, organism_index, return_scaled=False, channels_last=True)

Returns predictions in experimental or model scale, as selected by return_scaled.

Parameters:
  • embeddings_dict – Dict mapping resolution (1 or 128) to embeddings in NCL format (B, C, S)

  • organism_index – Organism indices (B,)

  • return_scaled – If True, return predictions in model space (for loss computation). If False, return them in experimental space (for inference).

  • channels_last – Output format:
      - True (default): NLC format (B, S, T); user-friendly and matches JAX.
      - False: NCL format (B, T, S); avoids transposes, for training efficiency.

Returns:

Dict mapping resolution to predictions in the format selected by channels_last
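
The channels_last flag only changes axis order. As an illustration with plain tensors of made-up sizes (not a call into GenomeTracksHead), an NCL output (B, T, S) becomes the NLC output (B, S, T) by swapping the last two axes:

```python
import torch

B, T, S = 2, 3, 8  # batch, tracks, sequence bins (illustrative sizes)

# NCL layout, as used internally by Conv1d: (B, T, S)
preds_ncl = torch.randn(B, T, S)

# NLC layout, as returned with channels_last=True: (B, S, T)
preds_nlc = preds_ncl.transpose(1, 2)

# Same values, different axis order: position 5, track 1
assert preds_nlc[0, 5, 1] == preds_ncl[0, 1, 5]
print(tuple(preds_ncl.shape), tuple(preds_nlc.shape))  # (2, 3, 8) (2, 8, 3)
```

transpose returns a non-contiguous view, so call .contiguous() before operations such as .view() that require contiguous memory.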

scale(x, organism_index, resolution, channels_last=True)

Scales targets from experimental to model prediction space.

Parameters:
  • x – Targets in experimental space. NLC (B, S, T) if channels_last else NCL (B, T, S)

  • organism_index – Organism indices (B,)

  • resolution – Bin resolution in base pairs (1 or 128)

  • channels_last – If True, x is NLC. If False, x is NCL.

Returns:

Targets in model space (same format as input)

unscale(x, organism_index, resolution, channels_last=True)

Unscales predictions from model space back to experimental data space (the inverse of scale).

class alphagenome_pytorch.heads.MultiOrganismConv1d(*args, **kwargs)

Bases: Module

Organism-specific 1x1 conv for NCL format (B, C, S).

Equivalent to the JAX _MultiOrganismLinear, which operates on NLC format. Using Conv1d avoids transpose overhead when the data is already NCL.
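
The equivalence rests on the fact that a kernel-size-1 convolution over NCL input applies, at every position, the same affine map as a linear layer over NLC input. The sketch below checks this with standard torch.nn.Conv1d and torch.nn.Linear sharing weights; it leaves out the organism-specific weight selection, and the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C_in, C_out, S = 2, 4, 3, 10

linear = nn.Linear(C_in, C_out)
conv = nn.Conv1d(C_in, C_out, kernel_size=1)

# Share parameters: Conv1d weight is (C_out, C_in, 1), Linear weight is (C_out, C_in)
with torch.no_grad():
    conv.weight.copy_(linear.weight.unsqueeze(-1))
    conv.bias.copy_(linear.bias)

x_ncl = torch.randn(B, C_in, S)          # NCL input (B, C, S)
y_conv = conv(x_ncl)                      # (B, C_out, S), no transpose needed
y_lin = linear(x_ncl.transpose(1, 2))     # Linear needs NLC (B, S, C)

assert torch.allclose(y_conv, y_lin.transpose(1, 2), atol=1e-5)
```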

Parameters:

init_scheme (Literal['truncated_normal', 'uniform']) – Weight initialization scheme.

forward(x, organism_index)

reset_parameters()

class alphagenome_pytorch.heads.MultiOrganismLinear(*args, **kwargs)

Bases: Module

Linear layer with organism-specific weights. Expects NLC format (B, S, C).

Used for non-sequence operations such as ContactMapsHead on pair activations.

JAX: alphagenome_research.model.heads._MultiOrganismLinear
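
Organism-specific weights can be sketched as one weight matrix and bias per organism, selected by indexing with organism_index. The class PerOrganismLinear below is hypothetical (it is not the library's MultiOrganismLinear), and its initialization is an arbitrary choice. It flattens all middle axes, so the same layer applies to sequence activations (B, S, C) and to pair activations (B, S, S, D).

```python
import torch
import torch.nn as nn


class PerOrganismLinear(nn.Module):
    """Hypothetical sketch: one (in, out) weight matrix and bias per organism."""

    def __init__(self, num_organisms: int, in_features: int, out_features: int):
        super().__init__()
        self.weight = nn.Parameter(
            torch.randn(num_organisms, in_features, out_features) * in_features ** -0.5
        )
        self.bias = nn.Parameter(torch.zeros(num_organisms, out_features))

    def forward(self, x: torch.Tensor, organism_index: torch.Tensor) -> torch.Tensor:
        # x: channels-last (B, ..., C_in); organism_index: (B,) integer tensor
        w = self.weight[organism_index]               # (B, C_in, C_out)
        b = self.bias[organism_index]                 # (B, C_out)
        flat = x.reshape(x.shape[0], -1, x.shape[-1])  # (B, N, C_in)
        y = torch.bmm(flat, w) + b.unsqueeze(1)       # (B, N, C_out)
        return y.reshape(*x.shape[:-1], w.shape[-1])


torch.manual_seed(0)
layer = PerOrganismLinear(num_organisms=2, in_features=8, out_features=3)
organism_index = torch.tensor([0, 1])
seq = torch.randn(2, 16, 8)     # (B, S, C)
pair = torch.randn(2, 4, 4, 8)  # (B, S, S, D) pair activations
print(layer(seq, organism_index).shape)   # torch.Size([2, 16, 3])
print(layer(pair, organism_index).shape)  # torch.Size([2, 4, 4, 3])
```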

Parameters:

init_scheme (Literal['truncated_normal', 'uniform']) – Weight initialization scheme.

forward(x, organism_index)

reset_parameters()

class alphagenome_pytorch.heads.SpliceSitesClassificationHead(*args, **kwargs)

Bases: Module

Predicts splice site classification.

Internal computation uses NCL layout (B, C, S) for Conv1d efficiency. By default, outputs are in NLC format (B, S, 5) to match the JAX reference.

Classes: Donor+, Acceptor+, Donor-, Acceptor-, Not a splice site (+/- denote the strand).

JAX: alphagenome_research.model.heads.SpliceSitesClassificationHead
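
Because the five classes are mutually exclusive at each position, a natural way to read a (B, S, 5) output is as a softmax over the last axis. The sketch below assumes the head's output consists of logits (this documentation does not say whether it returns logits or probabilities) and uses random tensors in place of real predictions.

```python
import torch

CLASSES = ["Donor+", "Acceptor+", "Donor-", "Acceptor-", "Not a splice site"]

torch.manual_seed(0)
B, S = 2, 6
logits = torch.randn(B, S, len(CLASSES))  # assumed NLC logits (B, S, 5)

probs = torch.softmax(logits, dim=-1)     # normalise over the 5 classes
pred = probs.argmax(dim=-1)               # (B, S) most likely class per base

print([CLASSES[i] for i in pred[0].tolist()])
```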

forward(embeddings_1bp, organism_index, channels_last=True)

class alphagenome_pytorch.heads.SpliceSitesJunctionHead(*args, **kwargs)

Bases: Module

Predicts splice junction read counts. Expects NCL format (B, C, S).

JAX: alphagenome_research.model.heads.SpliceSitesJunctionHead

forward(embeddings_1bp, organism_index, channels_last=True, **kwargs)

Parameters:
  • embeddings_1bp – Embeddings at 1 bp resolution in NCL format (B, C, S)

  • organism_index – Organism indices (B,)

  • splice_site_positions – Splice site positions (B, 4, P); a required keyword argument passed through **kwargs

Returns:

Dict containing pred_counts (B, P, P, 2*T), positions, and mask
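
As a shape-level illustration of how a pairwise (B, P, P, ...) tensor can arise, the hypothetical sketch below gathers 1 bp embeddings at P donor and P acceptor positions with torch.gather and scores every donor/acceptor pair with a bilinear product. It is not the head's actual computation: the scoring function, the meaning of the 4 rows of splice_site_positions, and the split of the 2*T output channels are not specified here, and random tensors stand in for real inputs.

```python
import torch

torch.manual_seed(0)
B, C, S, P, T = 2, 8, 100, 5, 3

emb = torch.randn(B, C, S)                  # NCL 1 bp embeddings (B, C, S)
donor_pos = torch.randint(0, S, (B, P))     # hypothetical donor positions (B, P)
acceptor_pos = torch.randint(0, S, (B, P))  # hypothetical acceptor positions (B, P)


def gather_positions(x: torch.Tensor, pos: torch.Tensor) -> torch.Tensor:
    """Select columns of an NCL tensor: (B, C, S) and (B, P) -> (B, P, C)."""
    idx = pos.unsqueeze(1).expand(-1, x.shape[1], -1)  # (B, C, P)
    return torch.gather(x, 2, idx).transpose(1, 2)


donors = gather_positions(emb, donor_pos)        # (B, P, C)
acceptors = gather_positions(emb, acceptor_pos)  # (B, P, C)

# Hypothetical per-track bilinear weights: score every donor/acceptor pair
w = torch.randn(T, C, C) / C
pair_scores = torch.einsum("bpc,tcd,bqd->bpqt", donors, w, acceptors)
print(pair_scores.shape)  # torch.Size([2, 5, 5, 3])
```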

class alphagenome_pytorch.heads.SpliceSitesUsageHead(*args, **kwargs)

Bases: Module

Predicts splice site usage.

Internal computation uses NCL layout (B, C, S) for Conv1d efficiency. By default, outputs are in NLC format (B, S, T) to match the JAX reference.

Outputs the proportion of RNA using each splice site.

JAX: alphagenome_research.model.heads.SpliceSitesUsageHead
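
Since usage is a proportion, predictions should lie in [0, 1]. One common way to enforce that, assumed here rather than taken from the implementation, is a sigmoid over per-track logits, paired with a binary cross-entropy loss against observed usage fractions (which accepts soft targets). The sketch also applies the NCL-to-NLC transpose described above; all tensors are random placeholders.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, T, S = 2, 4, 10
logits_ncl = torch.randn(B, T, S)        # assumed NCL logits (B, T, S)
logits_nlc = logits_ncl.transpose(1, 2)  # NLC (B, S, T)

usage = torch.sigmoid(logits_nlc)        # proportions in (0, 1)

# Hypothetical observed usage fractions in [0, 1], same NLC shape
observed = torch.rand(B, S, T)
loss = F.binary_cross_entropy_with_logits(logits_nlc, observed)
```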

forward(embeddings_1bp, organism_index, channels_last=True)

alphagenome_pytorch.heads.predictions_scaling(x, track_means, resolution, apply_squashing, soft_clip_value=10.0, channels_last=True)

Scales predictions to experimental data scale.

Matches JAX: alphagenome_research.model.heads.predictions_scaling

Parameters:
  • x (torch.Tensor) – Model predictions; NLC (B, S, C) if channels_last, else NCL (B, C, S)

  • track_means (torch.Tensor) – Mean values per track (B, C)

  • resolution (int) – Bin resolution (1 or 128)

  • apply_squashing (bool) – Whether to apply power law expansion (for RNA-seq)

  • soft_clip_value (float) – Value for soft clipping

  • channels_last (bool) – If True, x is NLC. If False, x is NCL.

Returns:

Scaled predictions in experimental data space (same format as input)

Return type:

torch.Tensor

alphagenome_pytorch.heads.targets_scaling(x, track_means, resolution, apply_squashing, soft_clip_value=10.0, channels_last=True)

Scales targets from experimental data to model prediction space.

Inverse of predictions_scaling; used to scale targets before loss computation.

Matches JAX: alphagenome_research.model.heads.targets_scaling

Parameters:
  • x (torch.Tensor) – Targets in experimental space; NLC (B, S, C) if channels_last, else NCL (B, C, S)

  • track_means (torch.Tensor) – Per-track scaling factors (B, C)

  • resolution (int) – Resolution multiplier (1 or 128)

  • apply_squashing (bool) – Whether to apply power law compression (for RNA-seq)

  • soft_clip_value (float) – Value for soft clipping

  • channels_last (bool) – If True, x is NLC. If False, x is NCL.

Returns:

Targets in model space (same format as input)

Return type:

torch.Tensor
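
The documentation states that targets_scaling is the inverse of predictions_scaling but does not give the formulas. The sketch below is a hypothetical reconstruction for illustration only: it assumes division by track_means * resolution, a power-law exponent of 0.75 when apply_squashing is set, and a square-root style soft clip above soft_clip_value. The function names end in _sketch to mark them as hypothetical, and only the NLC (channels_last=True) layout with non-negative values is handled.

```python
import torch

SQUASH_EXPONENT = 0.75  # assumed power-law exponent, not taken from the library


def soft_clip(x: torch.Tensor, c: float) -> torch.Tensor:
    # Identity up to c, then sqrt-like growth; continuous with slope 1 at x == c
    return torch.where(x > c, 2.0 * torch.sqrt(x * c) - c, x)


def soft_unclip(y: torch.Tensor, c: float) -> torch.Tensor:
    # Exact inverse of soft_clip
    return torch.where(y > c, (y + c) ** 2 / (4.0 * c), y)


def targets_scaling_sketch(x, track_means, resolution, apply_squashing, soft_clip_value=10.0):
    """Experimental -> model space. x: non-negative NLC (B, S, C); track_means: (B, C)."""
    x = x / (track_means.unsqueeze(1) * resolution)
    if apply_squashing:
        x = x ** SQUASH_EXPONENT  # power-law compression
    return soft_clip(x, soft_clip_value)


def predictions_scaling_sketch(x, track_means, resolution, apply_squashing, soft_clip_value=10.0):
    """Model -> experimental space; inverse of targets_scaling_sketch."""
    x = soft_unclip(x, soft_clip_value)
    if apply_squashing:
        x = x ** (1.0 / SQUASH_EXPONENT)  # power-law expansion
    return x * track_means.unsqueeze(1) * resolution


torch.manual_seed(0)
targets = torch.rand(2, 16, 3) * 5000.0  # non-negative coverage, NLC (B, S, C)
track_means = torch.rand(2, 3) + 0.5     # (B, C)
model_space = targets_scaling_sketch(targets, track_means, 128, apply_squashing=True)
recovered = predictions_scaling_sketch(model_space, track_means, 128, apply_squashing=True)
assert torch.allclose(recovered, targets, rtol=1e-4, atol=1e-3)
```

Under these assumptions the round trip recovers the targets up to floating-point error, which is the property the loss pipeline relies on: targets are compressed into model space for training, and predictions are expanded back for inference.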

Embeddings

class alphagenome_pytorch.embeddings.OutputEmbedder(*args, **kwargs)

Bases: Module

Output embedder using Conv1d for NCL format (B, C, S).

Matches JAX: alphagenome_research.model.embeddings.OutputEmbedder

Logic:
  1. Conv1d projection to output channels.
  2. Optional skip connection addition (with a projection if needed).
  3. Add organism embedding.
  4. Norm + GELU.
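
The four steps can be sketched as a small module. OutputEmbedderSketch is hypothetical: the kernel size of the projection, the 1x1 Conv1d used to project the skip input, the nn.Embedding used for the organism embedding, and LayerNorm over channels as the norm are all assumptions, since the documentation names the steps but not their exact layers.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class OutputEmbedderSketch(nn.Module):
    """Hypothetical sketch of the four steps, on NCL tensors (B, C, S)."""

    def __init__(self, in_channels, out_channels, num_organisms, skip_channels=None):
        super().__init__()
        self.proj = nn.Conv1d(in_channels, out_channels, kernel_size=1)  # step 1
        # Step 2: project the skip input only if its width differs
        self.skip_proj = None
        if skip_channels is not None and skip_channels != out_channels:
            self.skip_proj = nn.Conv1d(skip_channels, out_channels, kernel_size=1)
        self.organism_embed = nn.Embedding(num_organisms, out_channels)  # step 3
        self.norm = nn.LayerNorm(out_channels)  # step 4 (norm type assumed)

    def forward(self, x, organism_index, skip_x=None):
        h = self.proj(x)  # (B, C_out, S)
        if skip_x is not None:
            h = h + (self.skip_proj(skip_x) if self.skip_proj is not None else skip_x)
        h = h + self.organism_embed(organism_index).unsqueeze(-1)  # broadcast over S
        h = self.norm(h.transpose(1, 2)).transpose(1, 2)  # normalise over channels
        return F.gelu(h)


torch.manual_seed(0)
embedder = OutputEmbedderSketch(in_channels=16, out_channels=8, num_organisms=2, skip_channels=32)
x = torch.randn(2, 16, 50)
skip = torch.randn(2, 32, 50)
out = embedder(x, torch.tensor([0, 1]), skip_x=skip)
print(out.shape)  # torch.Size([2, 8, 50])
```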

forward(x, organism_index, skip_x=None, channels_last=False)

class alphagenome_pytorch.embeddings.OutputPair(*args, **kwargs)

Bases: Module

Output embedder for pair activations (B, S, S, D).

Note: pair activations use a different layout from sequence data, with features in the last dimension D. LayerNorm therefore operates over the last dimension.
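
LayerNorm over the last dimension normalizes each (b, i, j) feature vector independently, so no transpose is needed for the (B, S, S, D) layout. The sketch below checks this with torch.nn.LayerNorm on random tensors; the final per-organism addition is a hypothetical illustration of broadcasting over both sequence axes, not a description of what OutputPair does with organism_index.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, S, D = 2, 6, 16
pair = torch.randn(B, S, S, D) * 3 + 1  # pair activations (B, S, S, D)

norm = nn.LayerNorm(D)                  # normalises over the last (feature) axis only
out = norm(pair)

# Every (b, i, j) feature vector now has ~zero mean and ~unit variance
mean = out.mean(dim=-1)
var = ((out - mean.unsqueeze(-1)) ** 2).mean(dim=-1)

# Hypothetical: adding a per-organism vector broadcast over both sequence axes
organism_embed = nn.Embedding(2, D)
organism_index = torch.tensor([0, 1])
out_with_org = out + organism_embed(organism_index)[:, None, None, :]  # (B, S, S, D)
```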

forward(x, organism_index)