Layers

Core Layers

Convolutions

class alphagenome_pytorch.convolutions.ConvBlock(*args, **kwargs)

Bases: Module

Convolution block for inputs in NCL format (B, C, S).
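NCL is the same layout that torch.nn.Conv1d uses. The sketch below uses plain PyTorch rather than ConvBlock itself, because ConvBlock's constructor arguments are not documented here. It shows how to move a (B, S, C) tensor into (B, C, S) and back.

```python
import torch
import torch.nn as nn

batch, seq_len, channels = 2, 128, 16

# Activations stored as (B, S, C), i.e. NLC / channels-last.
x_nlc = torch.randn(batch, seq_len, channels)

# Swap the sequence and channel axes to get NCL (B, C, S).
# .contiguous() materializes the transposed layout in memory.
x_ncl = x_nlc.transpose(1, 2).contiguous()

# Any Conv1d consumes NCL directly; padding=2 with kernel_size=5 keeps S fixed.
conv = nn.Conv1d(channels, 32, kernel_size=5, padding=2)
y_ncl = conv(x_ncl)  # (B, 32, S)

# Swap back if downstream code expects (B, S, C).
y_nlc = y_ncl.transpose(1, 2)

print(x_ncl.shape, y_ncl.shape, y_nlc.shape)
# torch.Size([2, 16, 128]) torch.Size([2, 32, 128]) torch.Size([2, 128, 32])
```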

forward(x)

class alphagenome_pytorch.convolutions.DnaEmbedder(*args, **kwargs)

Bases: Module

Embeds one-hot-encoded DNA into feature space. Expects NCL format (B, 4, S).
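A minimal sketch of building the (B, 4, S) input that DnaEmbedder expects. The helper one_hot_dna is hypothetical. The A, C, G, T channel order and the encoding of unknown bases (such as "N") as all-zero columns are assumptions, so check the library's own preprocessing before relying on either.

```python
from typing import List

import torch

# Assumption: channel order A, C, G, T; any other character maps to all zeros.
_BASE_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}


def one_hot_dna(seqs: List[str]) -> torch.Tensor:
    """Hypothetical helper: encode equal-length DNA strings as (B, 4, S) floats."""
    length = len(seqs[0])
    out = torch.zeros(len(seqs), 4, length)
    for b, seq in enumerate(seqs):
        if len(seq) != length:
            raise ValueError("all sequences must have the same length")
        for pos, base in enumerate(seq.upper()):
            idx = _BASE_TO_INDEX.get(base)
            if idx is not None:
                out[b, idx, pos] = 1.0
    return out


x = one_hot_dna(["ACGTN", "TTGCA"])
print(x.shape)  # torch.Size([2, 4, 5])
```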

forward(x)

class alphagenome_pytorch.convolutions.DownResBlock(*args, **kwargs)

Bases: Module

Downsampling residual block. Expects NCL format (B, C, S).
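To illustrate the general pattern of a downsampling residual block on NCL tensors, here is a hypothetical ToyDownResBlock. It is not the library's implementation. The GELU activation, the 1x1 projection on the residual path, and halving the sequence length with max pooling are all assumptions chosen for clarity.

```python
import torch
import torch.nn as nn


class ToyDownResBlock(nn.Module):
    """Hypothetical downsampling residual block (not the library's code).

    Input (B, C_in, S) -> output (B, C_out, S // 2).
    """

    def __init__(self, c_in: int, c_out: int, kernel_size: int = 5):
        super().__init__()
        # Odd kernel_size with padding kernel_size // 2 preserves S.
        self.conv = nn.Conv1d(c_in, c_out, kernel_size, padding=kernel_size // 2)
        # 1x1 projection so the residual path matches c_out channels.
        self.skip = nn.Conv1d(c_in, c_out, 1) if c_in != c_out else nn.Identity()
        self.act = nn.GELU()
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.conv(self.act(x)) + self.skip(x)  # residual at full resolution
        return self.pool(h)                        # halve the sequence length


block = ToyDownResBlock(16, 32)
y = block(torch.randn(2, 16, 128))
print(y.shape)  # torch.Size([2, 32, 64])
```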

forward(x)

class alphagenome_pytorch.convolutions.StandardizedConv1d(*args, **kwargs)

Bases: Conv1d

1D convolution with weight standardization and learned scaling.

Expects NCL format (B, C, S).
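Weight standardization normalizes each output filter to zero mean and unit variance over its input-channel and kernel entries before the convolution runs. A learned scale then restores a trainable magnitude. ToyStandardizedConv1d is a hypothetical sketch of that technique, not the library's code. The per-output-channel scale shape, eps=1e-5, and the restriction to padding_mode="zeros" are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyStandardizedConv1d(nn.Conv1d):
    """Hypothetical weight-standardized Conv1d with a learned scale.

    Assumes padding_mode="zeros" (F.conv1d applies plain zero padding).
    """

    def __init__(self, *args, eps: float = 1e-5, **kwargs):
        super().__init__(*args, **kwargs)
        self.eps = eps
        # Assumption: one learned scale per output channel, initialized to 1.
        self.scale = nn.Parameter(torch.ones(self.out_channels, 1, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = self.weight  # (C_out, C_in // groups, K)
        # Standardize each output filter over its (C_in, K) entries.
        mean = w.mean(dim=(1, 2), keepdim=True)
        var = ((w - mean) ** 2).mean(dim=(1, 2), keepdim=True)
        w_std = (w - mean) * torch.rsqrt(var + self.eps) * self.scale
        return F.conv1d(x, w_std, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)


conv = ToyStandardizedConv1d(16, 32, kernel_size=3, padding=1)
y = conv(torch.randn(2, 16, 100))
print(y.shape)  # torch.Size([2, 32, 100])
```

Because the filters are re-standardized on every call, multiplying conv.weight by a constant leaves the output essentially unchanged (up to the eps term). Only the scale and the bias control the output magnitude.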

forward(x)

class alphagenome_pytorch.convolutions.UpResBlock(*args, **kwargs)

Bases: Module

Upsampling residual block with skip connection. Expects NCL format (B, C, S).
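For the upsampling direction, the hypothetical ToyUpResBlock sketches how forward(x, unet_skip) could combine a low-resolution input with higher-resolution encoder features. It is for illustration only. Nearest-neighbour upsampling by a factor of 2, the 1x1 projection, additive merging of the skip, and the residual convolution are all assumptions.

```python
import torch
import torch.nn as nn


class ToyUpResBlock(nn.Module):
    """Hypothetical upsampling block with a U-Net skip (not the library's code).

    x: (B, C_in, S), unet_skip: (B, C_out, 2 * S) -> (B, C_out, 2 * S)
    """

    def __init__(self, c_in: int, c_out: int, kernel_size: int = 5):
        super().__init__()
        self.proj = nn.Conv1d(c_in, c_out, 1)
        self.conv = nn.Conv1d(c_out, c_out, kernel_size, padding=kernel_size // 2)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor, unet_skip: torch.Tensor) -> torch.Tensor:
        h = self.proj(x).repeat_interleave(2, dim=-1)  # nearest-neighbour x2
        h = h + unet_skip                              # merge encoder features
        return h + self.conv(self.act(h))              # residual refinement


up = ToyUpResBlock(32, 16)
out = up(torch.randn(2, 32, 64), torch.randn(2, 16, 128))
print(out.shape)  # torch.Size([2, 16, 128])
```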

forward(x, unet_skip)

Attention

class alphagenome_pytorch.attention.AttentionBiasBlock(*args, **kwargs)

Bases: Module

forward(x)

class alphagenome_pytorch.attention.MHABlock(*args, **kwargs)

Bases: Module

Multi-Head Attention block.

Matches JAX: alphagenome_research.model.attention.MHABlock

JAX uses precision=BF16_BF16_F32 for attention, meaning:

  • inputs in bfloat16

  • accumulation in float32

  • output cast back to the input dtype
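One way to emulate the BF16_BF16_F32 recipe in PyTorch is to round the operands to bfloat16 and then upcast them to float32 before the contraction. The function attention_logits_bf16_f32 is a hypothetical sketch of that idea applied to unscaled dot-product logits. It is not taken from MHABlock, and the (B, S, H, C) query/key layout is an assumption borrowed from apply_rope.

```python
import torch


def attention_logits_bf16_f32(q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
    """Hypothetical emulation of BF16_BF16_F32 for attention logits.

    q, k: (B, S, H, C). Returns unscaled logits (B, H, S, S) in q.dtype.
    """
    # 1. Inputs in bfloat16: round both operands to bfloat16.
    q_bf16 = q.to(torch.bfloat16)
    k_bf16 = k.to(torch.bfloat16)
    # 2. Accumulation in float32: upcast before contracting. The product of two
    #    bfloat16 values fits exactly in float32 (barring overflow/underflow).
    logits = torch.einsum("bqhc,bkhc->bhqk", q_bf16.float(), k_bf16.float())
    # 3. Output cast back to the input dtype.
    return logits.to(q.dtype)


q = torch.randn(1, 8, 2, 4)
k = torch.randn(1, 8, 2, 4)
logits = attention_logits_bf16_f32(q, k)
print(logits.shape, logits.dtype)  # torch.Size([1, 2, 8, 8]) torch.float32
```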

forward(x, attention_bias, compute_dtype=None)

class alphagenome_pytorch.attention.MLPBlock(*args, **kwargs)

Bases: Module

forward(x)

class alphagenome_pytorch.attention.PairMLPBlock(*args, **kwargs)

Bases: Module

forward(x)

class alphagenome_pytorch.attention.PairUpdateBlock(*args, **kwargs)

Bases: Module

forward(x, pair_rep, compute_dtype=None)

class alphagenome_pytorch.attention.RowAttentionBlock(*args, **kwargs)

Bases: Module

Self-attention block applied along rows of pairwise representations.

Matches JAX: alphagenome_research.model.attention.RowAttentionBlock

JAX uses precision=BF16_BF16_F32 for einsum operations.
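Attention "along rows" means that each row of a pairwise tensor is treated as its own sequence. The single-head ToyRowAttention is a hypothetical sketch of that pattern, assuming a pair representation of shape (B, S, S, C). It omits multiple heads, normalization, and the BF16_BF16_F32 precision handling noted above.

```python
import torch
import torch.nn as nn


class ToyRowAttention(nn.Module):
    """Hypothetical single-head self-attention along rows of a pair tensor.

    pair: (B, S, S, C). Row pair[b, i] is treated as a sequence of S tokens
    that attend to one another; different rows never interact.
    """

    def __init__(self, channels: int):
        super().__init__()
        self.to_q = nn.Linear(channels, channels, bias=False)
        self.to_k = nn.Linear(channels, channels, bias=False)
        self.to_v = nn.Linear(channels, channels, bias=False)
        self.scale = channels ** -0.5

    def forward(self, pair: torch.Tensor) -> torch.Tensor:
        q, k, v = self.to_q(pair), self.to_k(pair), self.to_v(pair)
        # logits[b, i, j, l] = <q[b, i, j], k[b, i, l]>: attention inside row i.
        logits = torch.einsum("bijc,bilc->bijl", q, k) * self.scale
        weights = logits.softmax(dim=-1)  # normalize over l
        return torch.einsum("bijl,bilc->bijc", weights, v)


torch.manual_seed(0)
attn = ToyRowAttention(8)
pair = torch.randn(1, 4, 4, 8)
out = attn(pair)
print(out.shape)  # torch.Size([1, 4, 4, 8])
```

A quick way to confirm the row-wise behaviour is to perturb one row and check that every other row of the output stays the same.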

forward(x, compute_dtype=None)

class alphagenome_pytorch.attention.SequenceToPairBlock(*args, **kwargs)

Bases: Module

forward(x)

alphagenome_pytorch.attention.apply_rope(x, positions=None, max_position=8192, inplace=False)

Applies rotary position embeddings (RoPE) to the input tensor.

Matches JAX: alphagenome_research.model.attention.apply_rope

All computations run in the input dtype (x.dtype), matching JAX. Under DtypePolicy.mixed_precision(), RoPE therefore computes in bfloat16; under DtypePolicy.full_float32(), it computes in float32.

Parameters:
  • x – Input tensor of shape (B, S, H, C).

  • positions – Optional position indices of shape (B, S).

  • max_position – Maximum position used in the frequency calculation.

  • inplace – If True, use a memory-efficient in-place implementation. This reduces memory overhead from ~2x to ~0.5x but modifies x in place.

Returns:

Tensor with RoPE applied. If inplace=True, this is x itself, modified in place.
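For reference, here is a minimal sketch of the standard "rotate-half" RoPE formulation on a (B, S, H, C) tensor. The function rope_sketch is hypothetical. It uses the common base-10000 frequency schedule instead of deriving frequencies from max_position, rotates the first and second halves of the channel dimension as pairs, and has no in-place path. Its outputs will therefore not necessarily match apply_rope.

```python
from typing import Optional

import torch


def rope_sketch(x: torch.Tensor, positions: Optional[torch.Tensor] = None,
                base: float = 10000.0) -> torch.Tensor:
    """Hypothetical "rotate-half" RoPE for x of shape (B, S, H, C), C even.

    All arithmetic stays in x.dtype. The base-10000 schedule is the common
    textbook choice and an assumption here, not apply_rope's max_position logic.
    """
    b, s, _, c = x.shape
    if positions is None:
        positions = torch.arange(s, device=x.device).expand(b, s)  # (B, S)
    half = c // 2
    idx = torch.arange(half, device=x.device).to(x.dtype)
    inv_freq = base ** (-idx / half)                       # (C/2,)
    angles = positions.to(x.dtype)[..., None] * inv_freq  # (B, S, C/2)
    cos = angles.cos()[:, :, None, :]                      # broadcast over heads
    sin = angles.sin()[:, :, None, :]
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1[i], x2[i]) pair by its position-dependent angle.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


x = torch.randn(2, 6, 3, 8)  # (B, S, H, C)
y = rope_sketch(x)
```

Two properties make such a sketch easy to sanity-check. Each rotation preserves vector norms, and the dot product between a rotated query at position m and a rotated key at position n depends only on m - n.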