AuriStream-1B

📚 Paper - 🌐 Project Page

AuriStream is a biologically inspired, GPT-style autoregressive Transformer trained to predict the next token in a stream of discrete speech tokens, called cochlear tokens. These tokens are produced by a companion “WavCoch” tokenizer, a model trained to predict the time-frequency cochleagram from a waveform, with an LFQ bottleneck for token read-out. AuriStream uses a long context window (approx. 20 s, 4096 tokens) and was trained on LibriLight (~60k hours) for 500k steps. It learns meaningful representations of, e.g., phoneme and word identity, and can predict future tokens to generate speech continuations. The model takes cochlear token IDs as input; pair it with a WavCoch tokenizer to convert audio to tokens.
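
The context figures quoted above imply an approximate token rate. The sketch below is only back-of-the-envelope arithmetic: the 20 s figure is stated as approximate, so the derived rate is approximate too, and the helper names are illustrative rather than part of the AuriStream or WavCoch code.

```python
# Figures taken from the model description; the duration is approximate.
CONTEXT_TOKENS = 4096
CONTEXT_SECONDS = 20.0


def approx_tokens_per_second(context_tokens=CONTEXT_TOKENS,
                             context_seconds=CONTEXT_SECONDS):
    """Rough token rate implied by the context window size and duration."""
    return context_tokens / context_seconds


def approx_token_count(duration_seconds):
    """Estimate how many cochlear tokens a clip of this length produces."""
    return int(round(duration_seconds * approx_tokens_per_second()))


rate = approx_tokens_per_second()
print(f"~{rate:.1f} tokens/s, ~{1000.0 / rate:.2f} ms of audio per token")
print("3 s prompt ->", approx_token_count(3.0), "tokens (estimate)")
```

For anything that depends on the exact rate, measuring it from the tokenizer output (number of tokens divided by clip duration) is more reliable than this estimate.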


Installation

pip install -U torch torchaudio transformers

This model uses custom code; when loading from Hugging Face, pass trust_remote_code=True.


Use case 1) Get hidden state embeddings for an audio waveform

import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# 1) Load the WavCoch tokenizer (audio -> token IDs)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# 2) Load the AuriStream LM (tokens -> hidden states / next-token prediction)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_librilight_ckpt500k", trust_remote_code=True
).to(device).eval()

# 3) Read an audio file (mono, 16 kHz recommended)
wav, sr = torchaudio.load("sample.wav")

if wav.size(0) > 1:  # stereo -> mono
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000

# 4) Quantize the audio to obtain cochlear token IDs
with torch.no_grad():
    # The quantizer forward method expects (B, 1, T); returns (B, L)
    token_ids = quantizer(wav.unsqueeze(0).to(device))['input_ids']  # (1, L)

# 5) Forward pass to obtain hidden states
with torch.no_grad():
    out = lm(token_ids, output_hidden_states=True)
    last_layer = out["hidden_states"][-1]   # (1, L, D): one D-dim vector per token
    last_layer_mean = last_layer.mean(dim=1)  # mean-pool over the token (time) axis -> (1, D)

print("Mean-pooled embedding shape:", last_layer_mean.shape)

Notes

  • output_hidden_states=True returns the hidden states of every layer; index the result to pick one (e.g. [-1] for the last layer, as above).
  • For phoneme/word segments, slice the time axis before pooling.
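
The segment note above can be sketched as follows. This is a minimal illustration, assuming you already have segment boundaries in seconds (for example from a forced aligner, which this model card does not provide) and that tokens are evenly spaced in time. The helper `segment_token_slice` is hypothetical and not part of the AuriStream code.

```python
import math


def segment_token_slice(start_s, end_s, n_tokens, total_seconds):
    """Map a [start_s, end_s) time span to a slice over the token axis.

    Assumes tokens are evenly spaced over the clip, so the token rate is
    n_tokens / total_seconds.
    """
    if not 0.0 <= start_s < end_s <= total_seconds:
        raise ValueError("expected 0 <= start_s < end_s <= total_seconds")
    tokens_per_sec = n_tokens / total_seconds
    start = math.floor(start_s * tokens_per_sec)
    # Always keep at least one token, and never run past the sequence end.
    stop = min(n_tokens, max(start + 1, math.ceil(end_s * tokens_per_sec)))
    return slice(start, stop)


# Example: a 3 s clip that produced 600 tokens, segment from 0.5 s to 1.0 s.
sl = segment_token_slice(0.5, 1.0, n_tokens=600, total_seconds=3.0)
print(sl)

# With the tensors from use case 1 (not executed here):
#   total_seconds = wav.size(1) / sr
#   sl = segment_token_slice(0.5, 1.0, last_layer.size(1), total_seconds)
#   segment_embedding = last_layer[:, sl, :].mean(dim=1)  # (1, D)
```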

Use case 2) Generate a speech continuation (cochlear token prediction)

import torch, torchaudio
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# WavCoch tokenizer (audio -> tokens)
quantizer = AutoModel.from_pretrained(
    "TuKoResearch/WavCochV8192", trust_remote_code=True
).to(device).eval()

# AuriStream LM (tokens -> next tokens)
lm = AutoModel.from_pretrained(
    "TuKoResearch/AuriStream1B_librilight_ckpt500k", trust_remote_code=True
).to(device).eval()

# Load and prepare a short prompt (e.g., the first 3 s of audio at 16 kHz)
prompt_seconds = 3
wav, sr = torchaudio.load("prompt.wav")
if wav.size(0) > 1:
    wav = wav.mean(dim=0, keepdim=True)
if sr != 16_000:
    wav = torchaudio.transforms.Resample(sr, 16_000)(wav)
    sr = 16_000
# Slice using an integer number of samples
n_samples = int(round(sr * prompt_seconds))
wav = wav[:, :n_samples]

# Quantize the prompt audio to get token IDs
with torch.no_grad():
    prompt_tokens = quantizer(wav.unsqueeze(0).to(device))['input_ids']  # (1, L)

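# Decide how many future tokens to generate ("roll-out")
# Use the actual prompt duration: the file may be shorter than prompt_seconds
prompt_duration = wav.size(1) / float(sr)
tokens_per_sec = prompt_tokens.size(1) / prompt_duration
rollout_seconds = 2
rollout_steps = int(round(tokens_per_sec * rollout_seconds))  # K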

# Generate future tokens
with torch.no_grad():
    # returns (pred_tokens, pred_logits); temp (temperature), top_k, top_p, and seed are optional
    pred_tokens, _ = lm.generate(
        prompt_tokens, rollout_steps, temp=0.7, top_k=50, top_p=0.95, seed=0
    )
    full_tokens = torch.cat([prompt_tokens, pred_tokens], dim=1)  # (1, L+K)
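
Because the model description gives a context window of 4096 tokens, it may be worth checking that the prompt plus the roll-out fits inside it before calling generate. The model card does not say how lm.generate behaves when the window is exceeded, so the guard below is a conservative sketch; `clamp_rollout_steps` is a hypothetical helper, not part of the AuriStream code.

```python
def clamp_rollout_steps(prompt_len, requested_steps, context_len=4096):
    """Limit the roll-out so that prompt + generated tokens <= context_len.

    context_len defaults to the 4096-token window from the model description.
    """
    if prompt_len < 0 or requested_steps < 0:
        raise ValueError("prompt_len and requested_steps must be non-negative")
    remaining = max(0, context_len - prompt_len)
    return min(requested_steps, remaining)


# A 600-token prompt leaves plenty of room for a 400-token roll-out.
print(clamp_rollout_steps(600, 400))
# A 4000-token prompt leaves room for only 96 more tokens.
print(clamp_rollout_steps(4000, 400))

# With the tensors from use case 2 (not executed here):
#   rollout_steps = clamp_rollout_steps(prompt_tokens.size(1), rollout_steps)
```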

Architecture overview

Schematic of the WavCoch tokenizer (panel A) and the AuriStream model (panel B).

Citation

If you use this model, please cite:

@inproceedings{tuckute2025cochleartokens,
  title     = {Representing Speech Through Autoregressive Prediction of Cochlear Tokens},
  author    = {Greta Tuckute and Klemen Kotar and Evelina Fedorenko and Daniel Yamins},
  booktitle = {Interspeech 2025},
  year      = {2025},
  pages     = {2180--2184},
  doi       = {10.21437/Interspeech.2025-2044},
  issn      = {2958-1796}
}

Model size: 1B params (Safetensors). Tensor type: F32.