import torch
import torch.nn as nn
import torch.nn.functional as F
from . import layers
class StandardizedConv1d(nn.Conv1d):
    """1D convolution with weight standardization and a learned per-channel scale.

    Expects NCL format ``(B, C, S)``.

    Each output channel's kernel is standardized over its fan-in
    ``(in_channels // groups, kernel_width)``. This mirrors a JAX kernel of
    shape ``(width, input_channels, output_channels)`` standardized over axes
    ``(0, 1)``. The kernel is then multiplied by ``self.scale``.

    Args:
        padding: ``'same'`` for TF/JAX-style SAME padding (output length
            ``ceil(S / stride)``; any odd leftover padding goes on the right),
            ``'valid'`` (or ``None``) for no padding, or an int / 1-tuple of
            explicit symmetric padding forwarded to ``F.conv1d``.
        Remaining arguments are as for ``nn.Conv1d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding='same', dilation=1, groups=1, bias=True):
        # Padding is applied in forward(), so the parent conv is built with padding=0.
        super().__init__(in_channels, out_channels, kernel_size, stride, padding=0, dilation=dilation, groups=groups, bias=bias)
        self.pad_mode = padding
        # Learned per-output-channel scale; broadcasts against (out, in // groups, k).
        self.scale = nn.Parameter(torch.ones(out_channels, 1, 1))

    def standardized_weight(self):
        """Return the standardized, scaled kernel of shape ``(out, in // groups, k)``."""
        w = self.weight
        mean = w.mean(dim=(1, 2), keepdim=True)
        var = w.var(dim=(1, 2), keepdim=True, unbiased=False)
        # Fan-in = number of weights feeding one output unit. With groups > 1
        # that is (in_channels // groups) * k, i.e. the size of one filter.
        fan_in = w[0].numel()
        # The floor keeps rsqrt finite for (near-)constant filters.
        scale_factor = torch.rsqrt(torch.clamp(var * fan_in, min=1e-4)) * self.scale
        return (w - mean) * scale_factor

    def _same_padding(self, length):
        """Return ``(left, right)`` TF/JAX SAME padding for a sequence of ``length``."""
        stride = self.stride[0]
        effective_kernel = self.dilation[0] * (self.kernel_size[0] - 1) + 1
        out_length = -(-length // stride)  # ceil(length / stride)
        pad_total = max((out_length - 1) * stride + effective_kernel - length, 0)
        pad_left = pad_total // 2
        # Odd totals put the extra element on the right, as JAX/TF do.
        return pad_left, pad_total - pad_left

    def forward(self, x):
        # x: (B, C, S) - NCL format
        w = self.standardized_weight()
        if self.pad_mode == 'same':
            x = F.pad(x, self._same_padding(x.shape[-1]))
            padding = 0
        elif self.pad_mode is None or self.pad_mode == 'valid':
            padding = 0
        else:
            # Explicit int / tuple padding: previously discarded, now honoured.
            padding = self.pad_mode
        return F.conv1d(x, w, self.bias, self.stride, padding, self.dilation, self.groups)
class ConvBlock(nn.Module):
    """Pre-activation conv block on NCL tensors ``(B, C, S)``: norm -> GELU -> conv.

    A width-1 kernel uses a plain ``nn.Conv1d`` (the same math as a
    per-position linear layer, with no transposes needed). Wider kernels use
    ``StandardizedConv1d`` with 'same' padding.

    ``name`` is accepted for signature compatibility and is not used.
    """

    def __init__(self, in_channels, out_channels, kernel_size, name=None):
        super().__init__()
        self.in_channels, self.out_channels, self.kernel_size = in_channels, out_channels, kernel_size
        self.norm = layers.RMSBatchNorm(in_channels)
        self.conv = (
            nn.Conv1d(in_channels, out_channels, kernel_size=1)
            if kernel_size == 1
            else StandardizedConv1d(in_channels, out_channels, kernel_size, padding='same')
        )

    def forward(self, x):
        # x stays in NCL layout throughout.
        normalized = self.norm(x)
        activated = layers.gelu(normalized)
        return self.conv(activated)
class DnaEmbedder(nn.Module):
    """Embed a one-hot DNA sequence into feature space.

    Expects NCL input ``(B, in_channels, S)`` (4 channels for one-hot DNA).
    Returns ``(B, num_channels, S)``: a stem convolution followed by a
    residual ``ConvBlock``.

    Args:
        in_channels: Input channels. Defaults to 4 (one-hot DNA).
        num_channels: Output feature width. Defaults to 768.
        stem_kernel_size: Width of the stem convolution. Defaults to 15.
        block_kernel_size: Width of the residual ConvBlock kernel. Defaults to 5.

    The defaults reproduce the original fixed architecture, including
    parameter names and shapes.
    """

    def __init__(self, in_channels=4, num_channels=768, stem_kernel_size=15, block_kernel_size=5):
        super().__init__()
        # JAX reference: Conv1D(768, 15) on one-hot input, then + ConvBlock(768, 5).
        self.conv1 = nn.Conv1d(in_channels, num_channels, kernel_size=stem_kernel_size, padding='same')
        self.block = ConvBlock(num_channels, num_channels, kernel_size=block_kernel_size)

    def forward(self, x):
        # x: (B, in_channels, S) - NCL format, no transposes needed
        out = self.conv1(x)
        return out + self.block(out)
class DownResBlock(nn.Module):
    """Residual block for the encoder path that widens the channel dimension.

    Expects NCL input ``(B, C, S)`` and returns ``(B, C + channel_increment, S)``.
    This block does not change the sequence length. NOTE(review): presumably
    the caller performs the actual downsampling; confirm.

    Args:
        in_channels: Input channel count ``C``.
        name: Accepted for signature compatibility; unused.
        channel_increment: Channels added by this block. Defaults to 128.
    """

    def __init__(self, in_channels, name=None, channel_increment=128):
        super().__init__()
        # Single source of truth for the widening. It is used both for the
        # block widths and for the shortcut padding, so the two cannot drift.
        self.channel_increment = channel_increment
        self.out_channels_int = in_channels + channel_increment
        self.block1 = ConvBlock(in_channels, self.out_channels_int, kernel_size=5)
        self.block2 = ConvBlock(self.out_channels_int, self.out_channels_int, kernel_size=5)

    def forward(self, x):
        # x: (B, C, S) - NCL format
        out = self.block1(x)
        # Zero-pad the identity path's channel dim (dim 1) to match block1's
        # output width. F.pad pads from the last dim backwards:
        # (left_S, right_S, left_C, right_C).
        shortcut = F.pad(x, (0, 0, 0, self.channel_increment))
        out = out + shortcut
        return out + self.block2(out)
class UpResBlock(nn.Module):
    """Upsampling residual block that merges a U-Net skip connection.

    Works on NCL tensors. ``x`` is ``(B, C, S)``, and ``unet_skip`` is
    expected to be ``(B, skip_channels, 2 * S)``. The result has the skip's
    shape.
    """

    def __init__(self, in_channels, skip_channels):
        super().__init__()
        self.conv_in = ConvBlock(in_channels, skip_channels, kernel_size=5)
        self.residual_scale = nn.Parameter(torch.ones(1))
        self.pointwise = ConvBlock(skip_channels, skip_channels, kernel_size=1)
        self.conv_out = ConvBlock(skip_channels, skip_channels, kernel_size=5)

    def forward(self, x, unet_skip):
        n_skip = unet_skip.shape[1]
        # Identity path keeps only the first n_skip channels of x.
        # Assumes in_channels >= skip_channels; TODO confirm at call sites.
        merged = self.conv_in(x) + x[:, :n_skip]
        # Repeat each sequence position twice (dim 2), then apply the learned scale.
        merged = merged.repeat_interleave(2, dim=2) * self.residual_scale
        merged = merged + self.pointwise(unet_skip)
        return merged + self.conv_out(merged)