|
|
--- |
|
|
library_name: transformers |
|
|
pipeline_tag: fill-mask |
|
|
tags: |
|
|
- genomics |
|
|
- dna |
|
|
- masked-lm |
|
|
- ntv3 |
|
|
- long-range |
|
|
- multi-species |
|
|
- conditioned |
|
|
- supervised |
|
|
- bigwig-prediction |
|
|
- functional-genomics |
|
|
license: other |
|
|
language: |
|
|
- code |
|
|
model_parameter_count: 622605044 |
|
|
--- |
|
|
|
|
|
<div style="background-color: rgba(255, 68, 68, 0.15); padding: 5px; border: 2px solid #ff4444; border-radius: 3px;"> |
|
|
<h3>⚠️ WARNING: Ablation Models Ahead</h3> |
|
|
<p>This 5-downsample model variant is <strong>experimental</strong> and intended solely for exploring model-structure ablations.</p>


<p><strong>It is NOT one of the main, recommended NTv3 models for results.</strong></p>
|
|
</div> |
|
|
|
|
|
## 🧬 NTv3: A Foundation Model for Genomics |
|
|
|
|
|
NTv3 is a series of foundation models designed to understand and generate genomic sequences. It unifies representation learning, functional prediction, and controllable sequence generation within a single, efficient U-Net-like architecture, and it models long-range dependencies, up to 1 Mb of context, at nucleotide resolution. Pretrained on 9 trillion base pairs, NTv3 excels at functional-track prediction and genome annotation across 24 animal and plant species, and it can be fine-tuned into a controllable generative model for genomic sequence design.

This checkpoint is a **post-trained (supervised) multi-species model** that predicts functional genomics tracks (BigWig) and genomic elements (BED) across multiple species. It builds on the pre-trained NTv3 model with additional conditioning mechanisms and task-specific heads. For more details, please refer to the [NTv3 paper placeholder].
|
|
|
|
|
## ⚖️ License Summary |
|
|
|
|
|
1. The Licensed Models are **only** available under this License for Non-Commercial Purposes. |
|
|
2. You are permitted to reproduce, publish, share and adapt the Output generated by the Licensed Model only for Non-Commercial Purposes and in accordance with this License. |
|
|
3. You may **not** use the Licensed Models or any of their Outputs: |
|
|
1. in connection with any Commercial Purposes, unless agreed by Us under a separate licence; |
|
|
2. to train, improve or otherwise influence the functionality or performance of any other third-party derivative model that is commercial or intended for a Commercial Purpose and is similar to the Licensed Models; |
|
|
3. to create models distilled or derived from the Outputs of the Licensed Models, unless such models are for Non-Commercial Purposes and open-sourced under the same license as the Licensed Models; or |
|
|
4. in violation of any applicable laws and regulations. |
|
|
|
|
|
## 📋 Model Summary |
|
|
|
|
|
- Architecture: Conditioned U-Net with adaptive layer norms → Transformer stack → multi-species prediction heads |
|
|
- Tokenizer: character-level over A T C G N + specials (`<unk>` `<pad>` `<mask>` `<cls>` `<eos>` `<bos>`) |
|
|
- Condition tokenizer: for species/assembly selection |
|
|
- Selective intermediate outputs: use config to save specific layers |
|
|
- Multi-species: 24 assemblies/species |
|
|
- Outputs: MLM logits + BigWig tracks + BED elements + condition logits |
|
|
- Dependencies: requires `transformers >= 4.55.0`


- Input size: the input sequence length must be a multiple of 128
|
|
- Note: custom code → use `trust_remote_code=True` |
|
|
|
|
|
## 🚀 Quickstart |
|
|
|
|
|
```python |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
repo = "InstaDeepAI/NTv3_5downsample_post" |
|
|
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True) |
|
|
model = AutoModel.from_pretrained(repo, trust_remote_code=True) |
|
|
|
|
|
# Prepare inputs |
|
|
batch = tokenizer(["ATCGNATCG", "ACGT"], add_special_tokens=False, padding=True, pad_to_multiple_of=128, return_tensors="pt") |
|
|
|
|
|
# Species tokens |
|
|
species = ['human', 'mouse'] |
|
|
species_ids = model.encode_species(species) |
|
|
|
|
|
# Forward pass |
|
|
out = model( |
|
|
input_ids=batch["input_ids"], |
|
|
species_ids=species_ids, |
|
|
) |
|
|
|
|
|
print(out.logits.shape) # MLM logits: (B, L, V = 11) |
|
|
print(out.bigwig_tracks_logits.shape) # BigWig predictions |
|
|
print(out.bed_tracks_logits.shape)   # BED element predictions
|
|
``` |
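
Since the model is a masked language model over nucleotides (pipeline tag `fill-mask`), you can also recover predictions at masked positions. The snippet below is a minimal sketch, reusing the `tokenizer` and `model` loaded above, and assuming the custom tokenizer recognizes `tokenizer.mask_token` inside the input string (as standard Hugging Face tokenizers do) and encodes it as a single token; adapt it if the NTv3 tokenizer handles masking differently.

```python
import torch

# Minimal sketch of masked-nucleotide prediction (fill-mask usage).
seq = "ATCG" * 32                                    # 128 bp, a multiple of 128
masked = seq[:10] + tokenizer.mask_token + seq[11:]  # mask the nucleotide at position 10
batch = tokenizer([masked], add_special_tokens=False, return_tensors="pt")
species_ids = model.encode_species(["human"])

with torch.no_grad():
    out = model(input_ids=batch["input_ids"], species_ids=species_ids)

pred_id = out.logits[0, 10].argmax(-1).item()        # most likely token at the masked position
print(tokenizer.convert_ids_to_tokens([pred_id]))    # e.g. ['A']
```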
|
|
|
|
|
## 💻 Outputs |
|
|
|
|
|
The model returns an `NTv3PostTrainedOutput` (or a tuple if `return_dict=False`) with: |
|
|
- `logits`: MLM predictions (B, L, vocab_size) |
|
|
- `bigwig_tracks_logits`: Functional genomics tracks (B, L', num_tracks) - optional |
|
|
- `bed_tracks_logits`: BED element predictions (if available) - optional |
|
|
- `embedding`: Final embedding after deconv tower (B, L', embed_dim) |
|
|
- `after_transformer_embedding`: Embedding after transformer tower (B, L', embed_dim) |
|
|
- `hidden_states`: Tuple of all layer hidden states (if `output_hidden_states=True`) - optional |
|
|
- `attentions`: Tuple of all transformer layer attention weights (if `output_attentions=True`) - optional |
|
|
|
|
|
```python |
|
|
out = model( |
|
|
input_ids=batch["input_ids"], |
|
|
species_ids=species_ids, |
|
|
output_hidden_states=True, |
|
|
output_attentions=True, |
|
|
) |
|
|
|
|
|
# Access outputs |
|
|
logits = out.logits |
|
|
bigwig_tracks = out.bigwig_tracks_logits |
|
|
bed_tracks = out.bed_tracks_logits |
|
|
embedding = out.embedding |
|
|
hidden_states = out.hidden_states # Tuple of all layer embeddings |
|
|
attentions = out.attentions # Tuple of all transformer layer attention weights |
|
|
``` |
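
As a quick sanity check, you can print the resolution of each head; this small sketch only reuses the variables extracted above and follows the field descriptions, where L is the tokenized input length and L' is the resolution of the track and embedding outputs (a disabled head may return `None`).

```python
# Sketch: compare the resolutions of the different output heads.
print("MLM logits:      ", tuple(logits.shape))          # (B, L, vocab_size)
print("BigWig tracks:   ", tuple(bigwig_tracks.shape))   # (B, L', num_tracks)
print("Final embedding: ", tuple(embedding.shape))       # (B, L', embed_dim)
print("Saved hidden states:", len(hidden_states))
```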
|
|
|
|
|
## 🔤 Tokenization |
|
|
|
|
|
```python |
|
|
enc = tokenizer("ATCGNATCG", add_special_tokens=False) |
|
|
print(enc["input_ids"]) # char-level IDs |
|
|
# To show all supported species: |
|
|
print(model.config.species_to_token_id.keys()) |
|
|
# Using human and mouse as examples


species_ids = model.encode_species(['human', 'mouse'])


print(species_ids.shape)  # (B,)


print(species_ids)  # e.g. [27, 29]
|
|
``` |
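
If you need to map species IDs back to names (for logging or sanity checks), the mapping exposed on the config can be inverted. This is a small sketch built only on the `model.config.species_to_token_id` dictionary shown above.

```python
# Sketch: invert the species -> token-id mapping to recover names from IDs.
id_to_species = {v: k for k, v in model.config.species_to_token_id.items()}
print([id_to_species[int(i)] for i in species_ids])  # e.g. ['human', 'mouse']
```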
|
|
|
|
|
## 🛠️ Selective intermediate outputs |
|
|
|
|
|
You can also save specific intermediate outputs with custom keys: |
|
|
|
|
|
```python |
|
|
|
|
|
from transformers import AutoConfig

config = AutoConfig.from_pretrained(repo, trust_remote_code=True) |
|
|
# Save embeddings from specific transformer layers |
|
|
config.embeddings_layers_to_save = (1, 2) |
|
|
# Save attention maps from specific layers/heads |
|
|
config.attention_maps_to_save = [(1, 0), (2, 1)] # (layer, head) |
|
|
# Save embeddings from specific deconv layers |
|
|
config.deconv_layers_to_save = (1, 2) |
|
|
|
|
|
model = AutoModel.from_pretrained(repo, config=config, trust_remote_code=True) |
|
|
# Access via core's output dict (these are saved in addition to hidden_states/attentions) |
|
|
core_out = model.core( |
|
|
input_ids=batch["input_ids"], |
|
|
species_ids=species_ids, |
|
|
output_hidden_states=True, |
|
|
output_attentions=True, |
|
|
) |
|
|
emb_1 = core_out['embeddings_1'] # Transformer layer 1 |
|
|
attn_1_0 = core_out['attention_map_layer_1_number_0'] # Layer 1, head 0 |
|
|
deconv_1 = core_out['embeddings_deconv_1'] # Deconv layer 1 |
|
|
``` |
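
To see which extra entries were actually saved, you can scan the core output for the custom keys. This sketch assumes `core_out` behaves like a mapping (as in the dictionary-style access above) and that the key prefixes match the examples shown.

```python
# Sketch: list every extra intermediate output saved via the config options above.
for key in core_out:
    if key.startswith(("embeddings_", "attention_map_")):
        print(key, tuple(core_out[key].shape))
```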
|
|
|
|
|
## 📊 Shapes & config summary |
|
|
|
|
|
| Parameter | Value | |
|
|
|-----------|-------| |
|
|
| Vocab size | 11 | |
|
|
| Token embedding dim | 16 | |
|
|
| Model (hidden) dim | 1536 | |
|
|
| FFN dim | 6144 | |
|
|
| Attention heads | 24 | |
|
|
| Transformer layers | 12 | |
|
|
| Downsample stages | 5 | |
|
|
| Condition dimensions | 1 | |
|
|
| Assemblies/species | 24 | |
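
These values can be cross-checked against the loaded config. The sketch below only uses `vocab_size` (a standard config attribute) and the `species_to_token_id` mapping shown earlier; other attribute names of the custom NTv3 config are not documented here, so they are omitted.

```python
# Sketch: cross-check a couple of table entries against the loaded config.
cfg = model.config
print(cfg.vocab_size)                # expected: 11
print(len(cfg.species_to_token_id))  # expected to cover the 24 assemblies/species
```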
|
|
|
|
|
|
|
|
## ⚡ Mixed precision |
|
|
|
|
|
This model was originally trained with mixed precision (bf16) in JAX and later ported to PyTorch. During JAX training, all weights were kept in full fp32 precision at all times, while certain computations ran in bf16 for efficiency. By default, this repository loads the model for full-precision (fp32) inference to ensure numerical stability, but it can also be used in mixed precision (bf16) for efficient long-range training and inference. Note that bfloat16 requires a GPU with bfloat16 support (e.g. A100, H100). Loading the model in mixed precision introduces small numerical differences relative to the original JAX model; these are usually insignificant, but be aware of them when using the model. |
|
|
|
|
|
To load the model with mixed precision, use the following code: |
|
|
|
|
|
```python |
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
repo = "InstaDeepAI/NTv3_5downsample_post" |
|
|
tokenizer = AutoTokenizer.from_pretrained(repo, trust_remote_code=True) |
|
|
model = AutoModel.from_pretrained(repo, trust_remote_code=True, |
|
|
stem_compute_dtype='bfloat16', |
|
|
down_convolution_compute_dtype='bfloat16', |
|
|
transformer_qkvo_compute_dtype='bfloat16', |
|
|
transformer_ffn_compute_dtype='bfloat16', |
|
|
up_convolution_compute_dtype='bfloat16', |
|
|
modulation_compute_dtype='bfloat16', |
|
|
) |
|
|
``` |
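
Before opting into the bf16 configuration, you can check whether the current GPU actually supports bfloat16. This is a small sketch using standard PyTorch utilities, independent of the NTv3 code.

```python
import torch

# Sketch: verify bfloat16 support before loading the model with bf16 compute dtypes.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    print("bfloat16 supported; the mixed-precision load above is safe to use.")
else:
    print("No bfloat16 support detected; prefer the default fp32 load.")
```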