Layers
Core Layers
Convolutions
- class alphagenome_pytorch.convolutions.ConvBlock(*args, **kwargs)
Bases: Module
Convolution block operating on NCL format (B, C, S), i.e. (batch, channels, sequence length).
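NCL is the layout `torch.nn.Conv1d` consumes natively, so channels-last activations of shape (B, S, C) need a transpose first. The sketch below illustrates only that shape contract. `ToyConvBlock` is hypothetical, and its normalization, activation, and kernel width are assumptions rather than the internals of ConvBlock.

```python
import torch
from torch import nn


class ToyConvBlock(nn.Module):
    """Hypothetical conv block on NCL tensors; not alphagenome_pytorch's ConvBlock."""

    def __init__(self, in_channels: int, out_channels: int, width: int = 5):
        super().__init__()
        # GroupNorm(1, C) normalizes over (C, S) per example; an assumed choice.
        self.norm = nn.GroupNorm(1, in_channels)
        self.act = nn.GELU()
        # padding="same" keeps the sequence length S unchanged (stride 1).
        self.conv = nn.Conv1d(in_channels, out_channels, width, padding="same")

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, C_in, S) -> (B, C_out, S)
        return self.conv(self.act(self.norm(x)))


x_nlc = torch.randn(2, 1024, 64)               # channels-last (B, S, C)
x_ncl = x_nlc.transpose(1, 2).contiguous()     # NCL (B, C, S)
y = ToyConvBlock(64, 96)(x_ncl)
print(y.shape)  # torch.Size([2, 96, 1024])
```

Only the channel dimension changes here; the sequence axis stays in the last position throughout.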
- class alphagenome_pytorch.convolutions.DnaEmbedder(*args, **kwargs)
Bases: Module
Embeds one-hot-encoded DNA into feature space. Expects NCL format (B, 4, S), one channel per nucleotide.
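A minimal sketch of producing the (B, 4, S) input from DNA strings follows. The helper `one_hot_dna` is hypothetical, and both the A, C, G, T channel order and the all-zero encoding of unknown bases such as N are assumptions; confirm them against the model you are feeding.

```python
import torch
import torch.nn.functional as F

# Assumed channel order A, C, G, T (not confirmed by this page).
_BASE_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3}


def one_hot_dna(seqs: list[str]) -> torch.Tensor:
    """Hypothetical helper: encode equal-length DNA strings as float (B, 4, S).

    Unknown bases (e.g. N) become all-zero columns.
    """
    if not seqs:
        raise ValueError("need at least one sequence")
    length = len(seqs[0])
    if any(len(s) != length for s in seqs):
        raise ValueError("all sequences must have the same length")
    # Map bases to indices; 4 is a sentinel for unknown bases.
    idx = torch.tensor(
        [[_BASE_TO_INDEX.get(b, 4) for b in s.upper()] for s in seqs]
    )  # (B, S), int64
    # One-hot over 5 classes, then drop the sentinel channel -> (B, S, 4).
    onehot = F.one_hot(idx, num_classes=5)[..., :4].float()
    # Move channels to dim 1 for the NCL layout: (B, 4, S).
    return onehot.transpose(1, 2).contiguous()


x = one_hot_dna(["ACGTN", "TTGCA"])
print(x.shape)  # torch.Size([2, 4, 5])
```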
- class alphagenome_pytorch.convolutions.DownResBlock(*args, **kwargs)
Bases: Module
Downsampling residual block. Expects NCL format (B, C, S).
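One common pattern for a downsampling residual block is sketched below for illustration: a convolutional branch is added to a skip path, then max pooling halves the sequence length. `ToyDownResBlock`, the pooling factor of 2, and zero-padding the skip path's channels are assumptions, not a description of DownResBlock's internals.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyDownResBlock(nn.Module):
    """Hypothetical downsampling residual block on NCL tensors (B, C, S)."""

    def __init__(self, in_channels: int, out_channels: int, width: int = 5):
        super().__init__()
        if out_channels < in_channels:
            raise ValueError("out_channels must be >= in_channels")
        self.extra = out_channels - in_channels
        self.conv = nn.Conv1d(in_channels, out_channels, width, padding="same")
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip path: zero-pad channels to match the conv output width.
        # F.pad pads the last dim first: (0, 0) for S, then (0, extra) for C.
        skip = F.pad(x, (0, 0, 0, self.extra))
        # Residual sum of the skip path and a pre-activated conv branch.
        y = skip + self.conv(F.gelu(x))
        # Halve the sequence length: (B, C_out, S) -> (B, C_out, S // 2).
        return self.pool(y)


block = ToyDownResBlock(64, 96)
y = block(torch.randn(2, 64, 1024))
print(y.shape)  # torch.Size([2, 96, 512])
```

Padding the skip path with zeros, rather than projecting it, keeps the identity path parameter-free; that is a design choice of this sketch only.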
Attention
- class alphagenome_pytorch.attention.MHABlock(*args, **kwargs)
Bases: Module
Multi-head attention block.
Matches JAX: alphagenome_research.model.attention.MHABlock
JAX uses precision=BF16_BF16_F32 for attention, meaning:
  - inputs in bfloat16
  - accumulation in float32
  - output cast back to the input dtype
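The precision mode described above can be emulated in PyTorch by rounding operands to bfloat16, upcasting them to float32 for the contraction, and casting the result back. Products of two bfloat16 values are exact in float32 (barring overflow or underflow), so this should differ from a native BF16_BF16_F32 kernel mainly in summation order. `einsum_bf16_f32` is a hypothetical helper, not part of alphagenome_pytorch; also note that on GPUs with TF32 enabled, the float32 einsum may itself run at reduced precision.

```python
import torch


def einsum_bf16_f32(equation: str, *operands: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: bf16 inputs, float32 accumulation, input-dtype output."""
    out_dtype = operands[0].dtype
    # Round each input to bfloat16, then upcast so the contraction runs in float32.
    ops32 = [op.to(torch.bfloat16).to(torch.float32) for op in operands]
    out = torch.einsum(equation, *ops32)
    # Cast the float32 result back to the caller's dtype.
    return out.to(out_dtype)


q = torch.randn(2, 128, 8, 64)  # (B, S, H, C)
k = torch.randn(2, 128, 8, 64)
logits = einsum_bf16_f32("bshc,bthc->bhst", q, k)
print(logits.shape, logits.dtype)  # torch.Size([2, 8, 128, 128]) torch.float32
```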
- class alphagenome_pytorch.attention.RowAttentionBlock(*args, **kwargs)
Bases: Module
Self-attention block applied along rows of pairwise representations.
Matches JAX: alphagenome_research.model.attention.RowAttentionBlock
JAX uses precision=BF16_BF16_F32 for einsum operations.
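Attending "along rows" of a pairwise tensor means that entry (i, j) attends over the other entries (i, j') of the same row i, so folding the row index into the batch reduces the problem to ordinary self-attention. The sketch assumes a channels-last pairwise layout (B, S, S, C), a single head, and no learned query/key/value projections; `row_self_attention` is hypothetical and is not RowAttentionBlock.

```python
import torch
import torch.nn.functional as F


def row_self_attention(pair: torch.Tensor) -> torch.Tensor:
    """Hypothetical single-head row attention on pairwise features (B, S, S, C)."""
    b, s_rows, s_cols, c = pair.shape
    # Fold rows into the batch: each row becomes an independent sequence.
    x = pair.reshape(b * s_rows, s_cols, c)
    # Identity Q/K/V for illustration; a real block would project x first.
    # Default scaling is 1 / sqrt(c).
    out = F.scaled_dot_product_attention(x, x, x)
    return out.reshape(b, s_rows, s_cols, c)


pair = torch.randn(2, 6, 6, 8)
out = row_self_attention(pair)
print(out.shape)  # torch.Size([2, 6, 6, 8])
```

Because rows never mix, changing one row of the input leaves every other row's output untouched.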
- alphagenome_pytorch.attention.apply_rope(x, positions=None, max_position=8192, inplace=False)
Applies Rotary Position Embeddings to the input tensor.
Matches JAX: alphagenome_research.model.attention.apply_rope
All computations use the input dtype (x.dtype), matching JAX behavior: under DtypePolicy.mixed_precision(), RoPE computes in bfloat16; under DtypePolicy.full_float32(), it computes in float32.
- Parameters:
  - x – Input tensor of shape (B, S, H, C).
  - positions – Optional position indices of shape (B, S).
  - max_position – Maximum position used in the frequency calculation.
  - inplace – If True, use a memory-efficient in-place implementation. This reduces memory overhead from ~2x to ~0.5x but modifies x in place.
- Returns:
  Tensor with RoPE applied. If inplace=True, this is x itself, modified in place.
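For orientation, here is a sketch of rotary embeddings on a (B, S, H, C) tensor using the common RoFormer schedule, in which channel pair k rotates at frequency `base ** (-2k / C)` with `base = 10000`, and channels are paired first half with second half. Both choices are assumptions: this page does not say how apply_rope derives its frequencies from max_position or which channels it pairs. `rope_sketch` is hypothetical, it is out-of-place only, and unlike the documented behavior it computes angles in float32 before casting cos/sin to x.dtype.

```python
from typing import Optional

import torch


def rope_sketch(x: torch.Tensor, positions: Optional[torch.Tensor] = None,
                base: float = 10000.0) -> torch.Tensor:
    """Hypothetical RoPE on x of shape (B, S, H, C); C must be even."""
    b, s, h, c = x.shape
    half = c // 2
    if positions is None:
        # Default to positions 0..S-1 for every batch element: (B, S).
        positions = torch.arange(s, device=x.device).expand(b, s)
    # One inverse frequency per channel pair: base ** (-2k / C), shape (half,).
    k = torch.arange(half, device=x.device, dtype=torch.float32)
    inv_freq = base ** (-k / half)
    # Angles of shape (B, S, 1, half), broadcast over the head axis.
    angles = positions.to(torch.float32)[..., None, None] * inv_freq
    cos, sin = angles.cos().to(x.dtype), angles.sin().to(x.dtype)
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1, x2) pair by its angle; returns a new tensor.
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)


x = torch.randn(2, 10, 3, 16)
out = rope_sketch(x)
print(out.shape)  # torch.Size([2, 10, 3, 16])
```

Because each channel pair is only rotated, per-position vector norms are preserved, and the dot product between a rotated query at position m and a rotated key at position n depends only on m - n.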