import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from . import layers
# Default position / relative-distance bound used in this module: it is the
# default ``max_position`` of ``apply_rope`` and, integer-divided by 16, the
# ``seq_length`` passed to ``_central_mask_features`` by ``SequenceToPairBlock``.
_MAX_RELATIVE_DISTANCE = 8192
def _apply_rope_inplace(x, cos_theta, sin_theta):
    """Rotate adjacent channel pairs of ``x`` in place (RoPE).

    Channels are treated as (even, odd) pairs along the last axis. Only the
    even-indexed half of ``x`` is cloned before being overwritten; the odd
    half is read directly from ``x`` and then overwritten last.

    Args:
        x: Tensor to rotate; overwritten with the result. The caller in this
            file passes a (B, S, H, C) tensor.
        cos_theta: Cosines of the rotation angles, broadcastable against
            ``x`` (the caller passes (B, S, 1, C)).
        sin_theta: Sines of the rotation angles, same layout as ``cos_theta``.

    Returns:
        The same tensor object ``x``, now holding the rotated values.
    """
    cos_even, sin_even = cos_theta[..., ::2], sin_theta[..., ::2]
    cos_odd, sin_odd = cos_theta[..., 1::2], sin_theta[..., 1::2]

    # The even half is needed again after it has been overwritten, so keep a
    # private copy of it.
    saved_even = x[..., ::2].clone()

    # even' = even * cos - odd * sin   (odd half of x is still untouched here)
    rotated_even = saved_even * cos_even - x[..., 1::2] * sin_even
    x[..., ::2] = rotated_even

    # odd' = even * sin + odd * cos    (uses the saved, pre-rotation even half)
    rotated_odd = saved_even * sin_odd + x[..., 1::2] * cos_odd
    x[..., 1::2] = rotated_odd

    return x
[docs]
def apply_rope(x, positions=None, max_position=_MAX_RELATIVE_DISTANCE, inplace=False):
    """Apply Rotary Position Embeddings (RoPE) to ``x``.

    Written as a port of ``alphagenome_research.model.attention.apply_rope``
    (JAX). Angles, cosines and sines are produced in ``x.dtype``; only the
    frequency denominators are first built in float32 and then cast.

    Args:
        x: Input tensor of shape (B, S, H, C).
        positions: Optional position indices of shape (B, S). When omitted,
            ``0 .. S-1`` is used with a leading batch axis of size 1.
        max_position: Upper bound used when laying out the frequencies.
        inplace: If True, ``x`` itself is overwritten and returned (via
            ``_apply_rope_inplace``); otherwise a new tensor is returned.

    Returns:
        Tensor with the same shape as ``x`` with RoPE applied.
    """
    _, seq_len, _, channels = x.shape
    device = x.device
    rope_dtype = x.dtype  # every RoPE op below runs in the input dtype

    if positions is None:
        # Shape (1, S); broadcasts over the batch in the multiplications below.
        positions = torch.arange(seq_len, device=device, dtype=rope_dtype).unsqueeze(0)
    else:
        # ``.to`` is a no-op when the dtype already matches.
        positions = positions.to(rope_dtype)

    n_pairs = channels // 2

    # Equivalent of geomspace(1, max_position - n_pairs + 1, n_pairs), spelled
    # as exp(linspace * ln 10) rather than torch.logspace so it also runs on
    # MPS (no aten::logspace.out kernel there, see pytorch/pytorch#141287).
    log10_last = math.log10(max_position - n_pairs + 1)
    exponents = torch.linspace(
        0.0, log10_last, steps=n_pairs, device=device, dtype=torch.float32
    )
    geometric = torch.exp(exponents * math.log(10.0))

    # Sum in float32, then cast once to the working dtype.
    pair_index = torch.arange(n_pairs, device=device, dtype=torch.float32)
    denominators = (pair_index + geometric).to(rope_dtype)
    inv_freq = 1.0 / denominators

    # (B, S, n_pairs) -> duplicate each angle for its (even, odd) channel pair
    # -> (B, S, 1, C) so it broadcasts over heads.
    angles = torch.einsum('bs,f->bsf', positions, inv_freq)
    angles = torch.repeat_interleave(angles, 2, dim=-1).unsqueeze(2)
    cos_theta, sin_theta = torch.cos(angles), torch.sin(angles)

    if inplace:
        return _apply_rope_inplace(x, cos_theta, sin_theta)

    # Out-of-place path: build the "rotated by 90 degrees" companion of x,
    # i.e. (-odd, even) interleaved back into the channel axis.
    neg_odd = -x[..., 1::2]
    even = x[..., ::2]
    x_rotated = torch.stack([neg_odd, even], dim=-1).flatten(start_dim=-2)
    return x * cos_theta + x_rotated * sin_theta
def _shift(x, query_length, key_length):
    """Realign relative-position logits so each row starts at its own diagonal.

    The last two axes of ``x`` are reinterpreted with their sizes swapped, the
    first ``n_rows`` elements are dropped, and the remainder is folded back
    into ``n_rows`` rows of ``n_diags - 1`` columns. The net effect is

        out[..., i, j] == x[..., i, n_rows - i + j]

    whenever that column index lies inside the row (which is the case for the
    caller in this file, where ``n_diags == 2 * n_rows`` and
    ``key_length == n_rows``).

    ``reshape`` is used instead of ``view`` so that inputs whose last two axes
    are not laid out contiguously (e.g. permuted tensors) are accepted too;
    for contiguous inputs the result is identical.

    Args:
        x: Tensor of shape (..., n_rows, n_diags).
        query_length: Unused; the row count is taken from ``x`` itself. Kept
            so existing call sites remain valid.
        key_length: Number of leading columns to keep in the result.

    Returns:
        Tensor of shape (..., n_rows, min(key_length, n_diags - 1)).
    """
    *batch_shape, n_rows, n_diags = x.shape

    # Reinterpret the trailing (n_rows, n_diags) block as (n_diags, n_rows).
    x = x.reshape(*batch_shape, n_diags, n_rows)
    # Dropping the first row removes exactly n_rows leading elements, which
    # shifts every subsequent output row by one extra column.
    x = x[..., 1:, :]
    # Fold back into n_rows rows; each row is now one element shorter.
    x = x.reshape(*batch_shape, n_rows, n_diags - 1)

    return x[..., :key_length]
def _central_mask_features(distances, feature_size, seq_length):
    """Binary positional features from exponentially widening central masks.

    Written as a port of
    ``alphagenome_research.model.attention._central_mask_features`` (JAX),
    whose widths are

        center_widths = arange(feature_size) + geomspace(
            1, seq_length - feature_size + 1, feature_size, endpoint=False)

    Feature ``i`` is 1.0 where ``distance < center_widths[i]`` and 0.0
    elsewhere.

    Args:
        distances: Tensor of distances, any shape ``(...)``.
        feature_size: Number of masks / output features.
        seq_length: Sequence length that sets the largest geomspace value.

    Returns:
        float32 tensor of shape ``(..., feature_size)``.
    """
    where = distances.device
    f32 = torch.float32
    index = torch.arange(feature_size, device=where, dtype=f32)

    # geomspace with endpoint=False: value[i] = start * (end / start) ** (i / n),
    # evaluated in log space.
    log_first = math.log(1.0)
    log_last = math.log(float(seq_length - feature_size + 1))
    per_step = (log_last - log_first) / feature_size
    geometric = torch.exp(torch.tensor(log_first, device=where, dtype=f32) + index * per_step)

    widths = index + geometric  # (feature_size,)

    # Broadcast (..., 1) against (feature_size,) -> (..., feature_size).
    inside = widths > distances.unsqueeze(-1)
    return inside.to(f32)
[docs]
class MHABlock(nn.Module):
    """Multi-head attention block with 8 query heads and one shared K/V head.

    Written as a port of ``alphagenome_research.model.attention.MHABlock``
    (JAX, which uses ``precision=BF16_BF16_F32`` for attention). Here the two
    attention matmuls run in ``compute_dtype``; the logits are upcast to
    float32 for scaling, bias, soft-capping and softmax, and the result is
    cast back to ``compute_dtype``.
    """

    def __init__(self, d_model):
        super().__init__()
        n_heads, qk_dim, v_dim = 8, 128, 192

        # Attribute names and creation order are kept as-is so parameter
        # names and ordering are unchanged.
        self.norm = layers.RMSBatchNorm(d_model, channels_last=True)
        self.q_proj = nn.Linear(d_model, n_heads * qk_dim, bias=False)
        self.norm_q = layers.LayerNorm(qk_dim)
        self.k_proj = nn.Linear(d_model, qk_dim, bias=False)
        self.norm_k = layers.LayerNorm(qk_dim)
        self.v_proj = nn.Linear(d_model, v_dim, bias=False)
        self.norm_v = layers.LayerNorm(v_dim)
        self.final_norm = layers.RMSBatchNorm(d_model, channels_last=True)
        self.linear_embedding = nn.Linear(n_heads * v_dim, d_model)
        # Softmax is a submodule so a forward hook can capture the attention
        # weights, e.g. mha.attn_softmax.register_forward_hook(...).
        self.attn_softmax = nn.Softmax(dim=-1)

    def forward(self, x, attention_bias, compute_dtype=None):
        """Run attention over ``x``.

        Args:
            x: Input of shape (B, S, D).
            attention_bias: Optional tensor added to the scaled logits; must
                broadcast against (B, 8, S, S). ``None`` skips the bias.
            compute_dtype: Dtype for projections and matmuls; defaults to
                ``x.dtype``.

        Returns:
            Output of ``final_norm`` applied to the projected attention result.
        """
        batch, seq_len, _ = x.shape
        if compute_dtype is None:
            compute_dtype = x.dtype

        normed = self.norm(x.to(compute_dtype))

        queries = self.norm_q(self.q_proj(normed).view(batch, seq_len, 8, 128))
        keys = self.norm_k(self.k_proj(normed).view(batch, seq_len, 1, 128))
        values = self.norm_v(self.v_proj(normed).view(batch, seq_len, 1, 192))

        # The normalised q/k tensors are fresh, so rotating them in place is safe.
        queries = apply_rope(queries, inplace=True)
        keys = apply_rope(keys, inplace=True)

        queries_bh = queries.permute(0, 2, 1, 3)  # (B, 8, S, 128)
        keys_bh = keys.permute(0, 2, 1, 3)  # (B, 1, S, 128), broadcast over heads

        # Matmul in compute dtype, then float32 for everything up to softmax.
        logits = torch.matmul(queries_bh, keys_bh.transpose(-2, -1)).float()  # (B, 8, S, S)
        logits = logits / math.sqrt(128.0)
        if attention_bias is not None:
            logits = logits + attention_bias.float()

        # Soft-cap: squashes logits smoothly into (-cap, cap).
        cap = 5.0
        logits = torch.tanh(logits / cap) * cap

        weights = self.attn_softmax(logits)

        values_bh = values.permute(0, 2, 1, 3)  # (B, 1, S, 192)
        context = torch.matmul(weights.to(compute_dtype), values_bh).float()  # (B, 8, S, 192)
        context = context.to(compute_dtype)

        # Merge heads: (B, 8, S, 192) -> (B, S, 8 * 192).
        merged = context.permute(0, 2, 1, 3).reshape(batch, seq_len, -1)
        return self.final_norm(self.linear_embedding(merged))
[docs]
class MLPBlock(nn.Module):
    """Pre- and post-normalised two-layer feed-forward block.

    Computes ``final_norm(fc2(relu(fc1(norm(x)))))`` with a hidden width of
    ``2 * d_model``.
    """

    def __init__(self, d_model):
        super().__init__()
        hidden_dim = d_model * 2
        self.norm = layers.RMSBatchNorm(d_model, channels_last=True)
        self.fc1 = nn.Linear(d_model, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, d_model)
        self.final_norm = layers.RMSBatchNorm(d_model, channels_last=True)

    def forward(self, x):
        """Apply the block to ``x`` (last axis of size ``d_model``)."""
        expanded = self.fc1(self.norm(x))
        contracted = self.fc2(F.relu(expanded))
        return self.final_norm(contracted)
[docs]
class AttentionBiasBlock(nn.Module):
    """Turn a pair representation into a per-head attention bias.

    The pair tensor is normalised, passed through GELU, projected to 8
    channels (one per head) and then upsampled 16x along both pair axes.
    """

    def __init__(self, pair_dim):
        super().__init__()
        self.norm = layers.RMSBatchNorm(pair_dim, channels_last=True)
        self.proj = nn.Linear(pair_dim, 8, bias=False)

    def forward(self, x):
        """Compute the bias.

        Args:
            x: Pair tensor of shape (B, s, s, D).

        Returns:
            Tensor of shape (B, 8, 16 * s, 16 * s).
        """
        bias = self.proj(F.gelu(self.norm(x)))  # (B, s, s, 8)

        # Each pair cell is repeated 16 times along rows, then along columns.
        for axis in (1, 2):
            bias = torch.repeat_interleave(bias, 16, dim=axis)

        # Channels-last -> heads-first: (B, S, S, 8) -> (B, 8, S, S).
        return bias.permute(0, 3, 1, 2)
[docs]
class SequenceToPairBlock(nn.Module):
    """Build a pairwise representation from a sequence representation.

    The sequence is mean-pooled by a factor of 16, normalised, and projected
    to 32 query/key heads of width 128. Pair logits combine a content term
    (q . k) with relative-position terms built from central-mask features,
    and are then projected to ``pair_dim`` channels and combined with two
    per-position projections of the pooled sequence.

    Args:
        d_model: Channel width of the input sequence representation.
        pair_dim: Channel width of the returned pair representation. The
            output projections (``linear_y_q``, ``linear_y_k``,
            ``linear_pair``) are sized with this value; the default of 128
            reproduces the previously hard-coded width.
    """

    def __init__(self, d_model, pair_dim=128):
        super().__init__()
        self.d_model = d_model
        self.pair_dim = pair_dim
        # Internal q/k layout is fixed at 32 heads of width 128 (each q/k
        # projection therefore has 32 * 128 = 4096 output features).
        self.num_heads = 32
        self.head_dim = 128
        self.pool = layers.Pool1d(kernel_size=16, stride=16, method='mean')
        self.norm_seq2pair = layers.LayerNorm(d_model, rms_norm=True)
        self.linear_q = nn.Linear(d_model, self.num_heads * self.head_dim, bias=False)
        self.linear_k = nn.Linear(d_model, self.num_heads * self.head_dim, bias=False)
        # Relative-position features (2 * num_heads wide) -> one vector per head.
        self.linear_pos_features = nn.Linear(2 * self.num_heads, self.num_heads * self.head_dim)
        self.q_r_bias = nn.Parameter(torch.zeros(1, 1, self.num_heads, self.head_dim))
        self.k_r_bias = nn.Parameter(torch.zeros(1, 1, self.num_heads, self.head_dim))
        # Output-side layers produce the pair channels, so they follow
        # pair_dim rather than the internal head width.
        self.linear_y_q = nn.Linear(d_model, pair_dim, bias=False)
        self.linear_y_k = nn.Linear(d_model, pair_dim, bias=False)
        self.linear_pair = nn.Linear(self.num_heads, pair_dim)

    def forward(self, x):
        """Compute the pair representation.

        Args:
            x: Sequence tensor of shape (B, S, D), channels last.

        Returns:
            Tensor of shape (B, S', S', pair_dim), where S' is the pooled
            sequence length.
        """
        # The pool is called on a channels-first tensor, hence the transposes.
        x_pooled = self.pool(x.transpose(1, 2)).transpose(1, 2)
        x_norm = self.norm_seq2pair(x_pooled)
        B, S_prime, _ = x_norm.shape

        q = self.linear_q(x_norm).view(B, S_prime, self.num_heads, self.head_dim)
        k = self.linear_k(x_norm).view(B, S_prime, self.num_heads, self.head_dim)

        # Relative offsets -S' .. S'-1, featurised in float32 and then cast to
        # the input dtype.
        range_vec = torch.arange(-S_prime, S_prime, device=x.device, dtype=torch.float32)
        pos_feat = _central_mask_features(
            torch.abs(range_vec), self.num_heads, _MAX_RELATIVE_DISTANCE // 16
        )
        sign = torch.sign(range_vec).unsqueeze(-1)
        # Symmetric features followed by their signed (antisymmetric) copies.
        pos_feat = torch.cat([pos_feat, sign * pos_feat], dim=-1)  # (2S', 2 * num_heads)
        pos_feat = pos_feat.to(x.dtype)
        pos_encoding = self.linear_pos_features(pos_feat).view(
            2 * S_prime, self.num_heads, self.head_dim
        )

        # Logits against every relative offset: (B, H, S', 2S').
        term_q = torch.einsum('bqhc,phc->bhqp', q + self.q_r_bias, pos_encoding)
        term_k = torch.einsum('bkhc,phc->bhkp', k + self.k_r_bias, pos_encoding)

        # Re-index from "relative offset" to "absolute partner position":
        # (B, H, S', S').
        rel_q_a = _shift(term_q, S_prime, S_prime)
        rel_k_a = _shift(term_k, S_prime, S_prime)

        rel_q_a = rel_q_a.permute(0, 2, 3, 1)  # (B, S', S', H), query-major
        rel_k_a = rel_k_a.permute(0, 3, 2, 1)  # (B, S', S', H), key axis moved to dim 2

        a = torch.einsum('bqhc,bkhc->bqkh', q, k)  # content term, (B, S', S', H)
        a = a + 0.5 * (rel_q_a + rel_k_a)

        # Per-position branches, broadcast along the other pair axis.
        x_gelu = F.gelu(x_norm)
        y_q = self.linear_y_q(x_gelu)  # (B, S', pair_dim)
        y_k = self.linear_y_k(x_gelu)  # (B, S', pair_dim)

        pair_act = self.linear_pair(a) + y_q.unsqueeze(2) + y_k.unsqueeze(1)
        return pair_act
[docs]
class RowAttentionBlock(nn.Module):
    """Single-head self-attention along each row of a pair representation.

    Written as a port of
    ``alphagenome_research.model.attention.RowAttentionBlock`` (JAX, which
    uses ``precision=BF16_BF16_F32`` for its einsums). Here both einsums run
    in ``compute_dtype``; scaling and softmax run in float32.
    """

    def __init__(self, pair_dim=128):
        super().__init__()
        self.norm = layers.LayerNorm(pair_dim, rms_norm=True)
        self.linear_q = nn.Linear(pair_dim, pair_dim, bias=False)
        self.linear_k = nn.Linear(pair_dim, pair_dim, bias=False)
        self.linear_v = nn.Linear(pair_dim, pair_dim)
        # Softmax is a submodule so a forward hook can capture the weights.
        self.attn_softmax = nn.Softmax(dim=-1)

    def forward(self, x, compute_dtype=None):
        """Attend within each row of ``x``.

        Args:
            x: Pair tensor of shape (B, P, Q, F); attention mixes the Q axis
                independently for every (B, P).
            compute_dtype: Dtype for projections and einsums; defaults to
                ``x.dtype``.

        Returns:
            Tensor with the same shape as ``x`` in ``compute_dtype``.
        """
        if compute_dtype is None:
            compute_dtype = x.dtype

        normed = self.norm(x.to(compute_dtype))
        queries = self.linear_q(normed)
        keys = self.linear_k(normed)
        values = self.linear_v(normed)

        # NOTE(review): the scale is fixed at 1/sqrt(128) regardless of
        # pair_dim; confirm this is intended if pair_dim is ever changed.
        inv_sqrt_dim = 1.0 / math.sqrt(128.0)
        scores = torch.einsum('bpqf,bpkf->bpqk', queries, keys).float() * inv_sqrt_dim
        weights = self.attn_softmax(scores)

        # Weighted sum of values in compute dtype, then cast back.
        mixed = torch.einsum('bpqk,bpkf->bpqf', weights.to(compute_dtype), values).float()
        return mixed.to(compute_dtype)
[docs]
class PairMLPBlock(nn.Module):
    """Feed-forward block for the pair representation.

    Computes ``linear2(relu(linear1(norm(x))))`` with a hidden width of
    ``2 * pair_dim``; no residual is added here.
    """

    def __init__(self, pair_dim=128):
        super().__init__()
        hidden_dim = 2 * pair_dim
        self.norm = layers.LayerNorm(pair_dim, rms_norm=True)
        self.linear1 = nn.Linear(pair_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, pair_dim)

    def forward(self, x):
        """Apply the block to ``x`` (last axis of size ``pair_dim``)."""
        hidden = F.relu(self.linear1(self.norm(x)))
        return self.linear2(hidden)
[docs]
class PairUpdateBlock(nn.Module):
    """Update the pair representation from the sequence representation.

    Adds the sequence-to-pair output to the running pair representation (or
    starts from it), then applies row attention and a pair MLP, each as a
    residual update.
    """

    def __init__(self, d_model, pair_dim=128):
        super().__init__()
        self.seq2pair = SequenceToPairBlock(d_model, pair_dim)
        self.row_attn = RowAttentionBlock(pair_dim)
        self.pair_mlp = PairMLPBlock(pair_dim)

    def forward(self, x, pair_rep, compute_dtype=None):
        """Produce the updated pair representation.

        Args:
            x: Sequence tensor of shape (B, S, D).
            pair_rep: Current pair tensor of shape (B, S/16, S/16, F), or
                ``None`` to start from the sequence-to-pair output alone.
            compute_dtype: Forwarded to the row-attention block.

        Returns:
            The updated pair tensor.
        """
        from_sequence = self.seq2pair(x)
        pair_rep = from_sequence if pair_rep is None else pair_rep + from_sequence

        attended = self.row_attn(pair_rep, compute_dtype=compute_dtype)
        pair_rep = pair_rep + attended

        return pair_rep + self.pair_mlp(pair_rep)