Named Outputs#
AlphaGenome predicts thousands of genomic tracks — chromatin accessibility, transcription, histone modifications, TF binding, and more. The named outputs API lets you work with these tracks by biological meaning (tissue, assay, ontology) rather than raw channel indices.
Setup#
from alphagenome_pytorch import AlphaGenome
from alphagenome_pytorch.named_outputs import TrackMetadataCatalog
model = AlphaGenome.from_pretrained("weights.pth", device="cuda")
catalog = TrackMetadataCatalog.load_builtin()
model.set_track_metadata_catalog(catalog)
# dna_onehot: one-hot encoded DNA sequence tensor; organism_index=0 selects human
out = model.predict(dna_onehot, organism_index=0, named_outputs=True)
out is a NamedOutputs object. Each output head (out.atac,
out.rna_seq, etc.) is a NamedOutputHead that holds predictions at one
or more resolutions. Index by resolution to get a NamedTrackTensor — a
tensor bundled with per-channel metadata:
out.atac # NamedOutputHead (all resolutions)
out.atac[128] # NamedTrackTensor at 128bp resolution
out.atac[128].tensor # the raw torch.Tensor
out.atac[128].tracks # tuple of TrackMetadata (one per channel)
Output types and track counts#
| Output | Resolutions | Human tracks | Mouse tracks | Raw dimension |
|---|---|---|---|---|
| ATAC-seq | 1bp, 128bp | 167 | 18 | 256 |
| DNase-seq | 1bp, 128bp | 305 | 67 | 384 |
| PRO-cap | 1bp, 128bp | 12 | — | 128 |
| CAGE | 1bp, 128bp | 546 | 188 | 640 |
| RNA-seq | 1bp, 128bp | 667 | 173 | 768 |
| ChIP-seq (TF) | 128bp | 1617 | 127 | 1664 |
| ChIP-seq (histone) | 128bp | 1116 | 183 | 1152 |
| Contact maps | 128bp | 28 | 8 | 28 |
Human/Mouse columns show the number of real (non-padding) tracks. The “Raw dimension” is the full tensor channel count — both organisms share the same dimensions, with padding filling the gap. Named outputs strip padding by default (see Padding tracks).
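The number of padding channels for each output is the raw dimension minus the number of real tracks. The following plain-Python check uses the human column of the table above, in table order. The `padding_channels` helper is written only for this illustration.

```python
# (human tracks, raw dimension) for each row of the table above, in table order
ROWS = [(167, 256), (305, 384), (12, 128), (546, 640),
        (667, 768), (1617, 1664), (1116, 1152), (28, 28)]


def padding_channels(real: int, raw: int) -> int:
    """Padding channels = raw tensor width minus real (non-padding) tracks."""
    if not 0 <= real <= raw:
        raise ValueError(f"invalid counts: real={real}, raw={raw}")
    return raw - real


paddings = [padding_channels(real, raw) for real, raw in ROWS]
print(paddings)                    # [89, 79, 116, 94, 101, 47, 36, 0]
print(padding_channels(18, 256))   # 238 (mouse ATAC: same width, far fewer real tracks)
```

The last row has no padding, because its raw dimension equals its human track count.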
Track metadata fields#
Each track carries metadata that you can filter on:
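| Field | Description |
|---|---|
| track_name | Human-readable identifier (e.g. UBERON:0000948 total RNA-seq) |
| ontology_curie | Cell/tissue ontology term (e.g. UBERON:0002107) |
| biosample_name | Sample name |
| assay_title | Assay description (e.g. total RNA-seq) |
| gtex_tissue | GTEx tissue name (e.g. Artery_Aorta) |
| histone_mark | Histone modification (e.g. H3K27ac) |
| transcription_factor | TF name (e.g. CTCF) |
| | Data origin |
| genetically_modified | Whether the sample was genetically modified |
| | Track mean (used for normalization) |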
Access fields directly on TrackMetadata objects:
track = out.atac[128].tracks[0]
track.ontology_curie # Direct attribute access
track.get('biosample_type') # Safe access (returns None if missing)
track.has('genetically_modified') # True if field exists and is not None
track.to_dict() # Serialize to plain dict
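A toy object makes the difference between attribute access, `get()` and `has()` easy to see. `ToyTrack` is a hypothetical stand-in, not the package's `TrackMetadata`. It stores optional fields in a dict and follows the behaviour described above: `get()` returns None for a missing field, and `has()` is True only when the field exists and is not None.

```python
class ToyTrack:
    """Hypothetical stand-in for TrackMetadata with optional extra fields."""

    def __init__(self, track_name, **extras):
        self.track_name = track_name
        self._extras = dict(extras)

    def __getattr__(self, name):
        # Only called when normal lookup fails: expose extras as attributes.
        extras = self.__dict__.get("_extras", {})
        if name in extras:
            return extras[name]
        raise AttributeError(name)

    def get(self, field, default=None):
        # Safe access: missing fields fall back to the default (None).
        return getattr(self, field, default)

    def has(self, field):
        # True only if the field exists and its value is not None.
        return self.get(field) is not None

    def to_dict(self):
        return {"track_name": self.track_name, **self._extras}


t = ToyTrack("UBERON:0002107 total RNA-seq", biosample_type="tissue",
             genetically_modified=None)
print(t.biosample_type, t.get("gtex_tissue"), t.has("genetically_modified"))
# tissue None False
```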
Filtering tracks#
The core method is .select(), available on NamedTrackTensor,
NamedOutputHead, and NamedOutputs. It returns a new object with only
the matching tracks — both the tensor and metadata are sliced together, so they
stay in sync.
By metadata field#
# By assay
total_rna = out.rna_seq[1].select(assay_title='total RNA-seq')
total_rna.tensor # already sliced to matching channels
total_rna.tracks # metadata in sync
# By histone mark
h3k27ac = out.chip_histone[128].select(histone_mark='H3K27ac')
# By tissue
aorta = out.splice_junctions[1].select(gtex_tissue='Artery_Aorta')
# By ontology
hepg2_rna = out.rna_seq[1].select(ontology_curie='EFO:0001187')
liver_rna = out.rna_seq[1].select(ontology_curie='UBERON:0002107')
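The reason the tensor and the metadata are sliced together is shown by the plain-Python stand-in below. `ToyTrackTensor` is hypothetical and much simpler than `NamedTrackTensor`. For readability it stores one value tuple per channel (channels first), whereas the real tensors keep channels in the last dimension. `select` computes the matching positions once and applies them to both the values and the metadata.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ToyTrackTensor:
    """Hypothetical stand-in: per-channel values plus per-channel metadata."""
    channels: tuple  # channels[i] holds the values for track i
    tracks: tuple    # tracks[i] is the metadata dict for track i

    def select(self, **criteria):
        # Keep channel i only if its metadata matches every criterion.
        keep = [i for i, t in enumerate(self.tracks)
                if all(t.get(k) == v for k, v in criteria.items())]
        # Slice values and metadata with the same positions so they stay aligned.
        return ToyTrackTensor(
            channels=tuple(self.channels[i] for i in keep),
            tracks=tuple(self.tracks[i] for i in keep),
        )


ntt = ToyTrackTensor(
    channels=((0.1, 0.2), (0.5, 0.4), (0.9, 0.7)),
    tracks=({"histone_mark": "H3K27ac"}, {"histone_mark": "H3K4me3"},
            {"histone_mark": "H3K27ac"}),
)
h3k27ac = ntt.select(histone_mark="H3K27ac")
print(h3k27ac.channels)  # ((0.1, 0.2), (0.9, 0.7))
```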
Multiple conditions#
# All kwargs are AND-ed together
ctcf_unmodified = out.chip_tf[128].select(
transcription_factor='CTCF',
genetically_modified=None, # field=None matches missing/null values
)
“Any of” matching#
# Pass a list for OR logic within one field
ctcf_or_foxa1 = out.chip_tf[128].select(
transcription_factor=['CTCF', 'FOXA1']
)
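The matching rules described above reduce to a small predicate:

- keyword arguments are AND-ed together;
- a list means "any of";
- None matches a missing or null value.

`matches` is a hypothetical helper written against plain dicts to illustrate these semantics. It is not the package's implementation.

```python
def _field_matches(value, expected) -> bool:
    """One field: None matches missing/null, a list means 'any of', else equality."""
    if expected is None:
        return value is None
    if isinstance(expected, (list, tuple, set, frozenset)):
        return value in expected
    return value == expected


def matches(track: dict, **criteria) -> bool:
    """All criteria must hold (AND); track.get() yields None for missing fields."""
    return all(_field_matches(track.get(field), expected)
               for field, expected in criteria.items())


tracks = [
    {"transcription_factor": "CTCF"},
    {"transcription_factor": "CTCF", "genetically_modified": True},
    {"transcription_factor": "FOXA1"},
    {"transcription_factor": "REST"},
]
ctcf_unmodified = [t for t in tracks
                   if matches(t, transcription_factor="CTCF", genetically_modified=None)]
ctcf_or_foxa1 = [t for t in tracks
                 if matches(t, transcription_factor=["CTCF", "FOXA1"])]
print(len(ctcf_unmodified), len(ctcf_or_foxa1))  # 1 3
```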
Custom predicate#
# When kwargs aren't expressive enough, use a predicate function
liver_related = out.rna_seq[128].select(
predicate=lambda t: 'liver' in (t.get('biosample_name') or '').lower()
)
Strand filtering#
out.rna_seq[1].select(strand='+') # positive
out.rna_seq[1].select(strand='-') # negative
out.rna_seq[1].select(strand='.') # unstranded
out.rna_seq[1].select(strand=['+', '-']) # stranded (either)
out.rna_seq[1].select(strand=['-', '.']) # non-positive
out.rna_seq[1].select(strand=['+', '.']) # non-negative
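Each strand call above keeps the tracks whose strand falls in an allowed set of values. The mapping below spells those sets out. Its keys are descriptive labels for this sketch only, not method or argument names.

```python
ALL_STRANDS = frozenset({"+", "-", "."})

# Allowed strand values for each select(strand=...) call shown above
STRAND_FILTERS = {
    "positive": {"+"},
    "negative": {"-"},
    "unstranded": {"."},
    "stranded": {"+", "-"},
    "non-positive": ALL_STRANDS - {"+"},  # '-' or '.'
    "non-negative": ALL_STRANDS - {"-"},  # '+' or '.'
}


def filter_strand(strands, allowed):
    """Return indices of tracks whose strand is in the allowed set."""
    return [i for i, s in enumerate(strands) if s in allowed]


strands = ["+", "-", ".", "+"]
print(filter_strand(strands, STRAND_FILTERS["non-positive"]))  # [1, 2]
print(filter_strand(strands, STRAND_FILTERS["stranded"]))      # [0, 1, 3]
```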
Masks and indices#
For loss computation or manual tensor slicing, you can get boolean masks or
integer indices without creating a new NamedTrackTensor:
# Boolean mask — useful for element-wise loss masking
mask = out.rna_seq[128].mask(ontology_curie='UBERON:0002107')
loss = ((preds - targets) ** 2 * mask).mean()
# Integer indices — useful for gather/index_select
indices = out.chip_tf[128].indices(transcription_factor='CTCF')
selected = preds[..., indices]
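`mask()` and `indices()` express the same selection in two forms: the indices are the positions where the mask is True. The plain-Python sketch below shows that relationship. The lists stand in for the tensors these methods return, and both helpers are hypothetical.

```python
def mask_from_tracks(tracks, **criteria):
    """Boolean mask with one entry per track: True where every criterion matches."""
    return [all(t.get(k) == v for k, v in criteria.items()) for t in tracks]


def indices_from_mask(mask):
    """Integer positions of the True entries, in channel order."""
    return [i for i, keep in enumerate(mask) if keep]


tracks = [{"transcription_factor": "CTCF"}, {"transcription_factor": "REST"},
          {"transcription_factor": "CTCF"}]
mask = mask_from_tracks(tracks, transcription_factor="CTCF")
indices = indices_from_mask(mask)
print(mask, indices)  # [True, False, True] [0, 2]

# Gathering with the indices keeps the same channels that the mask marks True.
preds = [0.3, 0.8, 0.5]
print([preds[i] for i in indices])  # [0.3, 0.5]
```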
Resolution-independent queries#
Track metadata (names, ontology, biosample, strand, etc.) is the same at all
resolutions — only the tensor’s sequence dimension differs. You can query
metadata on the NamedOutputHead without choosing a resolution:
head = out.rna_seq
head.num_tracks # same at 1bp and 128bp
head.to_dataframe() # pandas DataFrame, no resolution needed
head.indices(strand='+')
head.mask(strand='+')
Filter at the head level to apply to all resolutions at once:
plus_strand = out.rna_seq.select(strand='+')
plus_strand[1].tensor # 1bp, already filtered
plus_strand[128].tensor # 128bp, already filtered
Both orderings produce identical results:
# These are equivalent
out.rna_seq.select(strand='+')[128].tensor
out.rna_seq[128].select(strand='+').tensor
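The two orderings agree because channel selection touches only the track axis, and choosing a resolution touches only the sequence axis. The toy check below assumes a hypothetical layout that mirrors a (sequence, channels) tensor: each resolution is a list of positions, and each position is a list of per-channel values.

```python
def select_channels(rows, keep):
    """Keep the given channel indices at every sequence position."""
    return [[row[i] for i in keep] for row in rows]


strands = ["+", "-", "+"]  # one entry per track
keep = [i for i, s in enumerate(strands) if s == "+"]

head = {
    1: [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]],  # toy 1bp positions
    128: [[0.5, 0.6, 0.7]],                              # toy 128bp position
}

# Filter the whole head first, then pick a resolution ...
filtered_head = {res: select_channels(rows, keep) for res, rows in head.items()}
a = filtered_head[128]
# ... or pick a resolution first, then filter.
b = select_channels(head[128], keep)
print(a == b, a)  # True [[0.5, 0.7]]
```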
Cross-head filtering#
Filter across all heads and resolutions at once:
tissue_tracks = out.select(biosample_type='tissue')
for (output_name, resolution), ntt in tissue_tracks.items():
print(f"{output_name}@{resolution}bp: {ntt.num_tracks} tracks")
Variant effect scoring#
Arithmetic operators (+, -, *, /, abs, negation) preserve
metadata from the left operand:
ref = model.predict(ref_onehot, organism_index=0, named_outputs=True)
alt = model.predict(alt_onehot, organism_index=0, named_outputs=True)
# Filter and diff in one expression
chip_diff = (
alt.chip_histone[128].select(strand=['-', '.'])
- ref.chip_histone[128].select(strand=['-', '.'])
)
chip_diff.tensor # the difference tensor
chip_diff.tracks # metadata preserved from alt
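Because the result keeps the left operand's metadata, both operands need the same tracks in the same order. That is why the example applies the same `select()` to `alt` and `ref`. The sketch below is a hypothetical element-wise difference that checks this alignment explicitly before subtracting. This page does not say whether the package performs such a check, and the track names are toy data.

```python
def diff_aligned(alt_values, alt_tracks, ref_values, ref_tracks):
    """Element-wise alt - ref over channels; metadata comes from the left operand."""
    alt_names = [t["track_name"] for t in alt_tracks]
    ref_names = [t["track_name"] for t in ref_tracks]
    if alt_names != ref_names:
        raise ValueError("operands have different tracks or track order")
    return [a - r for a, r in zip(alt_values, ref_values)], alt_tracks


tracks = [{"track_name": "toy_track_1"}, {"track_name": "toy_track_2"}]
delta, delta_tracks = diff_aligned([1.5, 0.25], tracks, [1.0, 0.5], tracks)
print(delta)  # [0.5, -0.25]
```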
Padding tracks#
The raw model tensors include padding channels so that both organisms share the same tensor dimensions. For example, the ATAC head always outputs 256 channels, but only 167 correspond to real human experiments — the remaining 89 are padding placeholders.
Named outputs strip padding by default:
out = model.predict(dna_onehot, organism_index=0, named_outputs=True)
out.atac[128].num_tracks # 167 (real tracks only)
This matches the behavior of the official JAX AlphaGenome API, which also strips padding before exposing metadata to users.
Keeping padding (for training)#
During training you often want to keep the full tensor shape and mask out the padding channels in the loss, rather than slicing them away:
out = model.predict(
dna_onehot, organism_index=0,
named_outputs=True,
include_padding=True,
)
out.atac[128].num_tracks # 256 (includes padding)
# Boolean mask: True = real track, False = padding
mask = out.atac[128].padding_mask()
loss = ((preds - targets) ** 2 * mask).mean()
# Resolution-independent mask
mask = out.atac.padding_mask()
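`((preds - targets) ** 2 * mask).mean()` divides by every element, padding channels included. The masked loss is therefore scaled by the fraction of real tracks. To get the mean over real tracks only, divide the masked sum by the number of unmasked elements instead. The minimal plain-Python sketch below shows that normalization; `masked_mse` is a hypothetical helper.

```python
def masked_mse(preds, targets, mask):
    """Mean squared error over unmasked channels only.

    preds, targets: lists of rows (one row per sequence position, one value per channel).
    mask: one bool per channel; False entries (e.g. padding) are ignored.
    """
    total, count = 0.0, 0
    for p_row, t_row in zip(preds, targets):
        for p, t, keep in zip(p_row, t_row, mask):
            if keep:
                total += (p - t) ** 2
                count += 1
    if count == 0:
        raise ValueError("mask selects no channels")
    return total / count


preds = [[1.0, 2.0, 0.0], [3.0, 4.0, 0.0]]
targets = [[0.0, 2.0, 9.0], [1.0, 4.0, 9.0]]
mask = [True, True, False]  # last channel plays the role of padding
print(masked_mse(preds, targets, mask))  # 1.25
```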
Stripping padding after the fact#
out.strip_padding() # All heads → new NamedOutputs
out.atac.strip_padding() # One head → new NamedOutputHead
out.atac[128].strip_padding() # One tensor → new NamedTrackTensor
Checking individual tracks#
for track in out.atac[128].tracks:
print(track.track_name, track.is_padding)
Loading and building metadata catalogs#
The built-in catalog ships with the package and contains metadata extracted from the official JAX AlphaGenome checkpoint:
from alphagenome_pytorch.named_outputs import TrackMetadataCatalog
catalog = TrackMetadataCatalog.load_builtin('human') # or 'mouse'
# Inspect what's available
catalog.outputs(organism=0) # ['atac', 'cage', 'chip_histone', ...]
catalog.organisms # [0]
You can also load from files or build programmatically:
# From parquet / CSV / TSV
catalog = TrackMetadataCatalog.from_file("my_metadata.parquet")
# From a pandas DataFrame
import pandas as pd
df = pd.DataFrame({
'track_name': ['sample_A', 'sample_B'],
'output_type': ['atac', 'atac'],
'organism': [0, 0],
'biosample_type': ['tissue', 'cell_line'],
})
catalog = TrackMetadataCatalog.from_dataframe(df)
# Programmatically
from alphagenome_pytorch.named_outputs import TrackMetadata
catalog = TrackMetadataCatalog()
catalog.add_tracks(
"rna_seq",
[
TrackMetadata(0, "rna_seq", 0, "UBERON:0000948 total RNA-seq",
extras={"strand": "+", "assay_title": "total RNA-seq"}),
TrackMetadata(1, "rna_seq", 0, "UBERON:0000948 total RNA-seq",
extras={"strand": "-", "assay_title": "total RNA-seq"}),
],
organism=0,
)
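The DataFrame example above uses `track_name`, `output_type` and `organism` as its core columns, and `add_tracks` takes an output name and an organism. That suggests the catalog is organized per organism and output type. The plain-Python sketch below groups rows that way and lists the outputs for one organism, in the spirit of `catalog.outputs(organism=...)`. Both helpers are hypothetical and are not the package's internals.

```python
from collections import defaultdict

rows = [
    {"track_name": "sample_A", "output_type": "atac", "organism": 0},
    {"track_name": "sample_B", "output_type": "atac", "organism": 0},
    {"track_name": "UBERON:0000948 total RNA-seq", "output_type": "rna_seq",
     "organism": 0, "strand": "+"},
]


def group_tracks(rows):
    """Group track rows by (organism, output_type), preserving channel order."""
    groups = defaultdict(list)
    for row in rows:
        groups[(row["organism"], row["output_type"])].append(row)
    return dict(groups)


def outputs_for(groups, organism):
    """Sorted output names available for one organism."""
    return sorted(out for org, out in groups if org == organism)


groups = group_tracks(rows)
print(outputs_for(groups, 0))  # ['atac', 'rna_seq']
```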
Exporting to pandas#
df = out.rna_seq[128].to_dataframe() # One head, one resolution
df = out.rna_seq.to_dataframe() # One head (resolution-independent)
Allow empty results#
By default, .select() raises ValueError if no tracks match. Pass
allow_empty=True to get an empty result instead:
result = out.atac[128].select(biosample_name='nonexistent', allow_empty=True)
result.num_tracks # 0
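The `allow_empty` switch amounts to a guard on the number of matches. The hypothetical sketch below applies that behaviour to plain dicts.

```python
def select_indices(tracks, allow_empty=False, **criteria):
    """Indices of matching tracks; raise ValueError on no match unless allow_empty."""
    hits = [i for i, t in enumerate(tracks)
            if all(t.get(k) == v for k, v in criteria.items())]
    if not hits and not allow_empty:
        raise ValueError(f"no tracks match {criteria}")
    return hits


tracks = [{"biosample_name": "liver"}, {"biosample_name": "heart"}]
print(select_indices(tracks, biosample_name="nonexistent", allow_empty=True))  # []
```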
Comparison with JAX AlphaGenome#
This section is for users migrating from the JAX alphagenome /
alphagenome_research packages.
Loading metadata#
JAX:
model = dna_model.create_from_kaggle('all_folds')
metadata = model.output_metadata(dna_model.Organism.HOMO_SAPIENS)
PyTorch:
catalog = TrackMetadataCatalog.load_builtin('human')
model.set_track_metadata_catalog(catalog)
Making predictions#
JAX:
predictions = model.predict_interval(
interval,
requested_outputs={dna_model.OutputType.RNA_SEQ},
ontology_terms=['EFO:0001187'],
)
predictions.rna_seq.metadata
PyTorch:
out = model.predict(dna_onehot, organism_index=0, named_outputs=True)
out.rna_seq[1].to_dataframe()
Filtering#
JAX:
# Boolean mask with pandas
metadata = predictions.rna_seq.metadata
predictions.rna_seq.filter_tracks(
(metadata['Assay title'] == 'total RNA-seq').values
)
# Multiple conditions
metadata = predictions.chip_tf.metadata
predictions.chip_tf.filter_tracks(
((metadata['transcription_factor'] == 'CTCF')
& metadata['genetically_modified'].isnull()).values
)
PyTorch:
# Keyword arguments
out.rna_seq[1].select(assay_title='total RNA-seq')
# Multiple conditions
out.chip_tf[128].select(
transcription_factor='CTCF',
genetically_modified=None,
)
Strand filtering#
JAX:
# 6 dedicated methods
predictions.rna_seq.filter_to_positive_strand()
predictions.rna_seq.filter_to_nonpositive_strand()
predictions.splice_junctions.filter_to_strand('+')
PyTorch:
# All via select()
out.rna_seq[1].select(strand='+')
out.rna_seq[1].select(strand=['-', '.'])
Padding#
JAX:
# Manual boolean mask
padding = metadata.padding
mask = ~padding[OutputType.ATAC]
# Or via create_track_masks()
masks = metadata_lib.create_track_masks(
metadata,
requested_outputs={...},
requested_ontologies=None,
)
PyTorch:
# Stripped by default
out.atac[128].num_tracks # 167
# Or keep padding and mask it
out = model.predict(
dna_onehot, organism_index=0,
named_outputs=True,
include_padding=True,
)
mask = out.atac.padding_mask()
track.is_padding # per track
Feature comparison table#
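| Feature | JAX | PyTorch |
|---|---|---|
| Load metadata | model.output_metadata(organism) | TrackMetadataCatalog.load_builtin() |
| Filter by field | filter_tracks(boolean mask) | .select(field=value) |
| Filter null/missing | metadata[field].isnull() | .select(field=None) |
| Access metadata field | metadata[field] | track.field |
| Safe field access | — | track.get(field) |
| Check field exists | — | track.has(field) |
| Strand filtering | filter_to_*_strand() / filter_to_strand() | .select(strand=...) |
| Tissue filtering | | .select(gtex_tissue=...) |
| Get indices | manual numpy | .indices() |
| Get boolean mask | manual numpy | .mask() |
| Padding detection | metadata.padding | track.is_padding |
| Strip padding | | auto / .strip_padding() |
| Padding mask | ~metadata.padding[output_type] | .padding_mask() |
| To DataFrame | .metadata | .to_dataframe() |
| Arithmetic | direct on objects | direct on objects |
| Cross-head filtering | — | out.select() |
| Allow empty results | — | allow_empty=True |
| Load from DataFrame | — | TrackMetadataCatalog.from_dataframe() |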
Design notes#
Why metadata and tensor are bundled (NamedTrackTensor): After any
.select() call, the returned tensor and metadata are guaranteed to be
aligned — no manual index tracking needed.
Why metadata is resolution-independent (NamedOutputHead): Track metadata
doesn’t change between 1bp and 128bp — only the sequence dimension differs.
NamedOutputHead lets you query metadata and filter without choosing a
resolution first.