# Orthrus 4-track

A 10M-parameter Mamba RNA encoder, contrastively pretrained on 45M+ mature mRNA transcripts. Takes 4-track one-hot input.
## Model zoo
| Repo | Tracks | Embed dim | Objective | Used in |
|---|---|---|---|---|
| `antichronology/orthrus-4-track` | 4 | 512 | contrastive | Nature Methods publication |
| `antichronology/orthrus-6-track` | 6 | 512 | contrastive | Nature Methods publication |
| `antichronology/orthrus-small-6-track` | 6 | 256 | contrastive | Nature Methods publication |
| `antichronology/orthrus-mlm-6-track` | 6 | 512 | contrastive + MLM | Nature Methods publication |
| `quietflamingo/orthrus-base-4-track` | 4 | 256 | contrastive | Pre-publication |
| `quietflamingo/orthrus-large-4-track` | 4 | 512 | contrastive | Pre-publication |
| `quietflamingo/orthrus-large-6-track` | 6 | 512 | contrastive | Pre-publication |
## Inference interface
Every Orthrus model exposes the same three inference methods plus a one-hot helper:
| Method | Output shape | Notes |
|---|---|---|
| `representation(x, lengths, channel_last=True)` | `(B, D)` | Mean-pooled, padding-aware |
| `representation_unpooled(x, channel_last=True)` | `(B, L, D)` | Per-position hidden states |
| `predict_tokens(x, lengths, channel_last=True)` | `(B, L, 4)` | MLM logits over `[A, C, G, T]`. Available on MLM-pretrained repos (`*-mlm-*`); raises `NotImplementedError` on contrastive-only checkpoints. |
| `seq_to_oh(seq)` | `(L, 4)` | One-hot helper; ordering `[A, C, G, T]` (U is treated as T) |
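The `[A, C, G, T]` channel ordering is easy to get wrong when preparing inputs by hand. Here is a pure-NumPy sketch of the convention; `seq_to_oh_sketch` is a hypothetical stand-in for illustration only, since the real helper ships with each checkpoint:

```python
import numpy as np

# Hypothetical stand-in for model.seq_to_oh, illustrating the [A, C, G, T]
# channel ordering described above. The real helper ships with the checkpoint.
NUC_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3, "U": 3}  # U treated as T

def seq_to_oh_sketch(seq: str) -> np.ndarray:
    """One-hot encode an RNA/DNA string into an (L, 4) float32 array.

    Unknown characters (e.g. N) get an all-zero row.
    """
    oh = np.zeros((len(seq), 4), dtype=np.float32)
    for i, base in enumerate(seq.upper()):
        j = NUC_INDEX.get(base)
        if j is not None:
            oh[i, j] = 1.0
    return oh

oh = seq_to_oh_sketch("ACGU")  # rows are the identity: A, C, G, U->T
```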
## Environment setup
The model uses a Mamba state-space backbone, which requires CUDA. The recommended environment matches the Orthrus GitHub repo:
```bash
# Conda env with Python 3.10
mamba create -n orthrus python=3.10
mamba activate orthrus

# PyTorch + transformers + huggingface_hub
pip install 'torch>=2.2' 'transformers<4.46' 'huggingface_hub>=0.24' safetensors

# Mamba kernels (require CUDA; pin versions for the published checkpoints)
pip install causal-conv1d==1.2.0.post2 --no-build-isolation --no-cache-dir
pip install mamba-ssm==1.2.0.post1 --no-build-isolation --no-cache-dir

# GenomeKit, only if you want to build 6-track inputs from real transcripts
mamba install "genomekit>=6.0.0"
wget -O starter_build.sh https://raw.githubusercontent.com/deepgenomics/GenomeKit/main/starter/build.sh
chmod +x starter_build.sh
./starter_build.sh
```
## Load the model
```python
import torch
from transformers import AutoModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = AutoModel.from_pretrained(
    "antichronology/orthrus-4-track",
    trust_remote_code=True,
).to(device).eval()
```
## Pooled representation (4-track input)
The 4-track input is just one-hot-encoded nucleotides:
```python
sequence = (
    "TCATCTGGATTATACATATTTCGCAATGAAAGAGAGGAAGAAAAGGAAGCAGCAAAATATGTGGAGGCCCA"
    "ACAAAAGAGACTAGAAGCCTTATTCACTAAAATTCAGGAGGAATTTGAAGAACATGAAGTTACTTCCTCC"
)

oh = model.seq_to_oh(sequence).unsqueeze(0).to(device)  # (1, L, 4)
lengths = torch.tensor([oh.shape[1]], device=device)

with torch.no_grad():
    emb = model.representation(oh, lengths, channel_last=True)

# emb.shape == (1, D)
```
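`representation` mean-pools per-position states while ignoring padding, which is why it takes a `lengths` tensor. For intuition, here is a minimal torch sketch of such a masked mean over a padded batch; this is an illustration of the idea, not the model's actual internals:

```python
import torch

def masked_mean_pool(hidden: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Mean-pool (B, L, D) hidden states, ignoring padded positions.

    Sketch of padding-aware pooling as described for `representation`;
    not the model's internal code.
    """
    B, L, D = hidden.shape
    # mask[b, t] is True for real positions, False for padding
    mask = torch.arange(L, device=hidden.device)[None, :] < lengths[:, None]
    mask = mask.unsqueeze(-1).to(hidden.dtype)  # (B, L, 1)
    return (hidden * mask).sum(dim=1) / lengths[:, None].to(hidden.dtype)

# Two sequences of true length 3 and 5, padded to L=5:
h = torch.ones(2, 5, 8)
pooled = masked_mean_pool(h, torch.tensor([3, 5]))  # (2, 8)
```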
## Pooled representation (6-track input)

For 6-track models, you need two extra channels: a CDS track (1 at every third base of the coding sequence, 0 elsewhere) and a splice track (1 at the last base of each exon, 0 elsewhere). The cleanest way to build these from a real transcript is via GenomeKit:
```python
import numpy as np
import torch
from genome_kit import Genome

genome = Genome("gencode.v44")  # or whatever annotation you built

def find_transcript_by_gene_name(genome, gene_name):
    return [t for t in genome.transcripts if t.gene.name == gene_name]

def get_transcript_seq(transcript, genome):
    return "".join(genome.dna(exon) for exon in transcript.exons)

def build_cds_track(transcript):
    """1 at every 3rd base of the CDS, 0 in UTRs."""
    exons = transcript.exons
    L = sum(len(e) for e in exons)
    cds = transcript.cdss
    if not cds:
        return np.zeros(L, dtype=np.float32)
    strand = transcript.strand
    if strand == "+":
        sorted_cds = sorted(cds, key=lambda c: c.start)
        sorted_exons = sorted(exons, key=lambda e: e.start)
    else:
        sorted_cds = sorted(cds, key=lambda c: c.end, reverse=True)
        sorted_exons = sorted(exons, key=lambda e: e.end, reverse=True)
    first_cds = sorted_cds[0]
    cds_len = sum(len(c) for c in sorted_cds)
    # Walk exons in transcript order to measure the 5' UTR length.
    five_utr = 0
    for ex in sorted_exons:
        if strand == "+":
            if ex.end <= first_cds.start:
                five_utr += len(ex)
            elif ex.overlaps(first_cds):
                five_utr += max(0, first_cds.start - ex.start)
                break
            else:
                break
        else:
            if ex.start >= first_cds.end:
                five_utr += len(ex)
            elif ex.overlaps(first_cds):
                five_utr += max(0, ex.end - first_cds.end)
                break
            else:
                break
    three_utr = max(0, L - (five_utr + cds_len))
    body = np.zeros(cds_len, dtype=np.float32)
    body[0::3] = 1.0
    return np.concatenate([
        np.zeros(five_utr, dtype=np.float32),
        body,
        np.zeros(three_utr, dtype=np.float32),
    ])
```
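Once the 5' UTR and CDS lengths are known, the track itself is simple length arithmetic. A hypothetical length-based variant (`cds_track_from_lengths`, not part of the repo) that produces the same pattern:

```python
import numpy as np

# Hypothetical length-based variant of build_cds_track: same output pattern,
# driven by precomputed lengths instead of GenomeKit objects.
def cds_track_from_lengths(five_utr: int, cds_len: int, total: int) -> np.ndarray:
    body = np.zeros(cds_len, dtype=np.float32)
    body[0::3] = 1.0  # mark the first base of each codon
    return np.concatenate([
        np.zeros(five_utr, dtype=np.float32),
        body,
        np.zeros(max(0, total - five_utr - cds_len), dtype=np.float32),
    ])

track = cds_track_from_lengths(five_utr=5, cds_len=9, total=17)
# Ones at transcript positions 5, 8, 11 (codon starts); zeros in both UTRs.
```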
```python
def build_splice_track(transcript):
    """1 at the last base of each exon, 0 elsewhere."""
    exons = transcript.exons
    L = sum(len(e) for e in exons)
    track = np.zeros(L, dtype=np.float32)
    cumulative = 0
    for ex in exons:
        cumulative += len(ex)
        track[cumulative - 1] = 1.0
    return track
```
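The same idea works from bare exon lengths. A hypothetical length-based sketch (`splice_track_from_exon_lengths`, not part of the repo) that marks the last base of each exon given transcript-ordered exon lengths:

```python
import numpy as np

# Hypothetical length-based sketch of the splice track: mark the last base
# of each exon, given exon lengths in transcript order.
def splice_track_from_exon_lengths(exon_lengths: list[int]) -> np.ndarray:
    track = np.zeros(sum(exon_lengths), dtype=np.float32)
    cumulative = 0
    for length in exon_lengths:
        cumulative += length
        track[cumulative - 1] = 1.0
    return track

track = splice_track_from_exon_lengths([100, 80, 60])
# Ones at 0-based positions 99, 179, 239 (each exon's last base).
```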
```python
t = find_transcript_by_gene_name(genome, "BCL2L1")[0]
sequence = get_transcript_seq(t, genome)
cds = build_cds_track(t)
splice = build_splice_track(t)

oh = model.seq_to_oh(sequence).numpy()              # (L, 4)
x = np.hstack([oh, cds[:, None], splice[:, None]])  # (L, 6)
x = torch.tensor(x, device=device).unsqueeze(0)
lengths = torch.tensor([x.shape[1]], device=device)

with torch.no_grad():
    emb = model.representation(x, lengths, channel_last=True)

# emb.shape == (1, D)
```
If you already have CDS and splice arrays from another source (e.g. a UCSC genePred table), you can skip GenomeKit and just `np.hstack` them with the `seq_to_oh` output.
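A self-contained sketch of that GenomeKit-free path, with a local `one_hot` helper standing in for `model.seq_to_oh` and toy `cds`/`splice` arrays in place of real annotation data:

```python
import numpy as np
import torch

# Local stand-in for model.seq_to_oh ([A, C, G, T] channel order, U -> T).
def one_hot(seq: str) -> np.ndarray:
    idx = {"A": 0, "C": 1, "G": 2, "T": 3, "U": 3}
    oh = np.zeros((len(seq), 4), dtype=np.float32)
    for i, b in enumerate(seq.upper()):
        if b in idx:
            oh[i, idx[b]] = 1.0
    return oh

sequence = "ATGGCCTAA"  # toy transcript; use your real sequence here
# Toy tracks; in practice these come from your genePred table or similar.
cds = np.array([1, 0, 0, 1, 0, 0, 1, 0, 0], dtype=np.float32)
splice = np.zeros(len(sequence), dtype=np.float32)
splice[-1] = 1.0

x = np.hstack([one_hot(sequence), cds[:, None], splice[:, None]])  # (L, 6)
x = torch.from_numpy(x).unsqueeze(0)  # (1, L, 6), ready for the model
```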
## Un-pooled (per-position) representation
```python
with torch.no_grad():
    hidden = model.representation_unpooled(x, channel_last=True)

# hidden.shape == (1, L, D)
# Useful for: local scoring at a specific transcript position, attention
# probing, downstream sequence-tagging tasks.
```
## MLM token prediction

Only available on `*-mlm-*` checkpoints. Calling `predict_tokens` on a contrastive Orthrus model raises `NotImplementedError` with a pointer to the MLM repo.

To score the likelihood of a base at a specific transcript position, mask the nucleotide channels (set them to zero) at that position and call `predict_tokens`:
```python
import torch.nn.functional as F

# Start from a 4-track or 6-track input `x` built above, shape (1, L, C).
pos = 123  # 0-based transcript-coordinate position

x_masked = x.clone()
x_masked[0, pos, :4] = 0.0  # zero nucleotide channels only; keep CDS/splice intact

with torch.no_grad():
    logits = model.predict_tokens(x_masked, lengths, channel_last=True)  # (1, L, 4)

log_probs = F.log_softmax(logits[0, pos, :], dim=-1)
# log_probs[i] = log P(nucleotide i | masked context), i in [A, C, G, T]

# Variant-effect score for a SNV REF->ALT at this position:
REF, ALT = 0, 2  # e.g. A -> G
llr = (log_probs[ALT] - log_probs[REF]).item()
```
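The same masking recipe extends to scanning many positions (in-silico mutagenesis). A sketch, where the hypothetical `predict_fn` callable stands in for a bound `model.predict_tokens(..., channel_last=True)` call; random logits make the demo runnable without an MLM checkpoint:

```python
import torch
import torch.nn.functional as F

def snv_llr_scan(x: torch.Tensor, lengths: torch.Tensor, predict_fn, positions):
    """Log-likelihood ratios for all 3 possible SNVs at each given position.

    `predict_fn(x, lengths)` stands in for model.predict_tokens and must
    return (1, L, 4) logits. Returns {pos: {(ref, alt): llr}}.
    """
    scores = {}
    for pos in positions:
        ref = int(x[0, pos, :4].argmax())       # ref base from the one-hot input
        x_masked = x.clone()
        x_masked[0, pos, :4] = 0.0              # mask nucleotide channels only
        with torch.no_grad():
            logits = predict_fn(x_masked, lengths)
        lp = F.log_softmax(logits[0, pos, :], dim=-1)
        scores[pos] = {(ref, alt): (lp[alt] - lp[ref]).item()
                       for alt in range(4) if alt != ref}
    return scores

# Runnable demo with random logits in place of a real MLM checkpoint:
x = torch.zeros(1, 10, 4)
x[0, torch.arange(10), torch.randint(0, 4, (10,))] = 1.0
dummy_fn = lambda xm, ln: torch.randn(1, xm.shape[1], 4)
scores = snv_llr_scan(x, torch.tensor([10]), dummy_fn, positions=[2, 5])
```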
## Fine-tuning
Configuration files and training scripts for fine-tuning, linear probing, and homology-aware splitting live in the Orthrus GitHub repo. All fine-tuning data and pre-computed splits are mirrored on Zenodo.
## Citation

```bibtex
@article{fradkinShi2026,
  title = {Orthrus: toward evolutionary and functional RNA foundation models},
  ISSN = {1548-7105},
  url = {http://dx.doi.org/10.1038/s41592-026-03064-3},
  DOI = {10.1038/s41592-026-03064-3},
  journal = {Nature Methods},
  publisher = {Springer Science and Business Media LLC},
  author = {Fradkin, Philip and Shi, Ruian "Ian" and Dalal, Taykhoom and Isaev, Keren and Frey, Brendan J. and Lee, Leo J. and Morris, Quaid and Wang, Bo},
  year = {2026},
  month = apr
}
```
## License
MIT