Files changed (2)
  1. chatNT.py +36 -4
  2. text_generation.py +8 -13
chatNT.py CHANGED
@@ -28,6 +28,7 @@ class RotaryEmbeddingConfig:
 class PerceiverResamplerConfig:
     """
     Parameters to initialize an PerceiverResampler model.
+
     Args:
         emb_layer_norm_before: Whether to use layer norm before the first attention
             layer.
@@ -92,7 +93,9 @@ class PerceiverResamplerConfig:
 class GptConfig:
     """
     Parameters to initialize a Gpt model.
+
     NOTE: the pad token is not defined
+
     Args:
         vocab_size: Token vocabulary.
         eos_token_id: used to stop sentence generation
@@ -188,6 +191,7 @@ class GptConfig:
 class NucleotideTransformerConfig:
     """
     Parameters to initialize an NT model.
+
     Args:
         alphabet_size: Token vocabulary.
         pad_token_id: ID of pad token.
@@ -369,6 +373,7 @@ class TorchBioBrainDecoder(nn.Module):
         """
         Initializes the BioBrain decoder, using a GPT model for text generation with
         bio embeddings.
+
         Args:
             gpt_config: Configuration for the GPT model
             seq_token_id: Index of the SEQ token
@@ -385,11 +390,13 @@ class TorchBioBrainDecoder(nn.Module):
     ) -> torch.Tensor:
         """
         Forward pass through the model.
+
         Args:
             english_token_ids: Tensor of English token IDs with shape
                 (batch_size, num_english_tokens).
             projected_bio_embeddings: Optional tensor of bio embeddings with shape
                 (batch_size, num_bio_sequences, ?, embed_dim).
+
         Returns:
             torch.Tensor: The logits from the GPT model,
                 shaped (batch_size, num_english_tokens, vocab_size).
@@ -445,11 +452,13 @@ class TorchBioBrainDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Inserts resampled embeddings in input_embeddings, starting at the SEQ token
+
         Args:
             tokens (torch.Tensor): Shape (batch_size, num_tokens)
             input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
             resampled_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_sequences, bio_sequence_length, embed_dim)
+
         Returns:
             Tuple[torch.Tensor, torch.Tensor]:
                 - input_embeddings with resampled_embeddings inserted at the SEQ token
@@ -512,9 +521,11 @@ class TorchBioBrainDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Removes the logits corresponding to the unused embeddings.
+
         Args:
             tokens: Input english tokens.
             logits: Input logits.
+
         Returns:
             Cleaned logits, last values will be equal to 0.
         """
@@ -629,34 +640,39 @@ class TorchMultiOmicsModel(PreTrainedModel):
 
     def forward(
         self,
-        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor | None],
+        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
         projection_english_tokens_ids: torch.Tensor,
         projected_bio_embeddings: torch.Tensor = None,
    ) -> dict[str, torch.Tensor]:
         """
+
         Args:
             multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
                 english_tokens_ids: Represents the prompt tokens (english tokens)
                     Shape (batch_size, num_english_tokens)
+
                 bio_tokens_ids: Represents the bio sequences tokens
                     Shape (batch_size, num_bio_sequences, num_bio_tokens)
+
             projection_english_tokens_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
+
             projected_bio_embeddings (projected_bio_embeddings, optional):
                 Shape (batch_size, num_bio_sequencse, ?, embed_dim).
                 Defaults to None.
+
         Returns:
             dict[str, torch.Tensor] containing:
                 - logits:
                     Shape (batch_size, num_tokens, vocab_size)
+
                 - projected_bio_embeddings:
                     Shape (batch_size, num_bio_sequences, ?, embed_dim)
         """
         english_token_ids, bio_token_ids = multi_omics_tokens_ids
         english_token_ids = english_token_ids.clone()
+        bio_token_ids = bio_token_ids.clone()
         projection_english_tokens_ids = projection_english_tokens_ids.clone()
-        if bio_token_ids is not None:
-            bio_token_ids = bio_token_ids.clone()
         if projected_bio_embeddings is not None:
             projected_bio_embeddings = projected_bio_embeddings.clone()
 
@@ -724,6 +740,7 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
         """
         Create the sines and cosines for the RoPE.
+
         Returns:
             Sinusoidal positions of shape (self.max_seq_len, self.dim).
         """
@@ -756,9 +773,11 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
         """
         Prepare a tensor to apply the RoPE mechanism.
+
         Args:
             x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                 typically this is the key or query tensor.
+
         Returns:
             The even indices in the last dimension have their sign flipped.
             Tensor of shape (batch_size, seq_len, num_heads, head_dim).
@@ -775,10 +794,12 @@ class TorchRotaryEmbedding(torch.nn.Module):
     ) -> torch.Tensor:
         """
         Applies rotary embeddings to x.
+
         Args:
             x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                 typically this is the key or query tensor.
             sincos: Tuple of sine and cosine tensors for position encoding.
+
         Returns:
             RoPE embeddings tensor.
         """
@@ -796,10 +817,12 @@ class TorchRotaryEmbedding(torch.nn.Module):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Applies rotary embeddings to k and q.
+
         Args:
             k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
             q: value tensor of shape (batch_size, seq_len, num_heads, head_dim),
             positions: optional positions offset useful when caching,
+
         Returns:
             RoPE embeddings for the keys and values.
         """
@@ -1117,9 +1140,11 @@ def build_causal_attention_mask(
     """
     Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
     to an attention layer.
+
     Args:
         batch_size: Batch size.
         seq_len: Length of the sequences.
+
     Returns:
         Batch of causal masks.
     """
@@ -1525,11 +1550,13 @@ class TorchNucleotideTransformer(nn.Module):
     ) -> torch.Tensor:
         """
         Computes the embeddings based on the input tokens.
+
         Args:
             tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
             attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
                 If no mask is provided, a mask by default which equals 1 over all non
                 pad tokens and 0 over pad tokens is computed.
+
         Returns:
             Dictionary containing the final embeddings and logits.
         """
@@ -1557,9 +1584,11 @@ def build_padding_attention_mask(
 ) -> torch.Tensor:
     """
     Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
+
     Args:
         tokens: Batch of sequences of shape (batch_size, seq_len).
         pad_token_id: Int corresponding to the <pad> token to mask.
+
     Returns:
         Batch of attention masks, masking out <pad> tokens.
     """
@@ -1586,6 +1615,7 @@ class TorchBioBrainEncoder(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
+
         Returns:
             torch.Tensor:
                 Shape (batch_size, num_bio_tokens, embed_dim)
@@ -1695,6 +1725,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
     ):
         """
         Initialize a Perceiver Resampler model.
+
         Args:
             config: Dataclass containing model hyperparameters.
             name: Name for module (custom will break weight loading).
@@ -1823,8 +1854,10 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
+
             bio_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_tokens, embed_dim)
+
             english_token_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
         """
@@ -1867,4 +1900,3 @@ def build_perceiver_padding_attention_mask(
     padding_mask = padding_mask[:, None, None, :]
     padding_mask = padding_mask.repeat(1, 1, resampled_length, 1)  # noqa
     return padding_mask
-
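Beyond the docstring formatting fixes (a blank line inserted before each Args:/Returns: block), the behavioral change in this file is in TorchMultiOmicsModel.forward: multi_omics_tokens_ids is now typed tuple[torch.Tensor, torch.Tensor] and bio_token_ids is cloned unconditionally, so passing None for the bio tokens is no longer supported. A minimal calling sketch, assuming model is an already-instantiated TorchMultiOmicsModel; the shapes and vocabulary sizes are placeholders, not values from this PR:

import torch

# Placeholder token IDs; in practice these come from the English and bio tokenizers.
english_tokens = torch.randint(0, 100, (1, 64))   # (batch_size, num_english_tokens)
bio_tokens = torch.randint(0, 10, (1, 1, 512))    # (batch_size, num_bio_sequences, num_bio_tokens)

outs = model(
    multi_omics_tokens_ids=(english_tokens, bio_tokens),  # bio tokens are now required, not Optional
    projection_english_tokens_ids=english_tokens,         # placeholder; normally produced by the projection's English tokenizer
)
logits = outs["logits"]                           # (batch_size, num_tokens, vocab_size)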
 
text_generation.py CHANGED
@@ -55,24 +55,19 @@ class TextGenerationPipeline(Pipeline):
             truncation=True,
             max_length=english_tokens_max_length,
         ).input_ids
-        if len(dna_sequences) == 0:
-            bio_tokens = None
-        else:
-            bio_tokens = self.bio_tokenizer(
-                dna_sequences,
-                return_tensors="pt",
-                padding="max_length",
-                max_length=bio_tokens_max_length,
-                truncation=True,
-            ).input_ids.unsqueeze(0)
+        bio_tokens = self.bio_tokenizer(
+            dna_sequences,
+            return_tensors="pt",
+            padding="max_length",
+            max_length=bio_tokens_max_length,
+            truncation=True,
+        ).input_ids.unsqueeze(0)
 
         return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
 
     def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
         english_tokens = model_inputs["english_tokens"].clone()
-        bio_tokens = model_inputs["bio_tokens"]
-        if bio_tokens is not None:
-            bio_tokens = bio_tokens.clone()
+        bio_tokens = model_inputs["bio_tokens"].clone()
         projected_bio_embeddings = None
 
         actual_num_steps = 0
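The pipeline is updated to match: preprocess no longer special-cases an empty dna_sequences list (the bio_tokens = None fallback is gone) and _forward clones bio_tokens unconditionally, so every call must include at least one DNA sequence. A usage sketch; the checkpoint id and input keys below are assumptions for illustration, not something defined in this PR:

from transformers import pipeline

# Assumed setup: the repository ships this file as a custom text-generation pipeline.
pipe = pipeline(model="InstaDeepAI/ChatNT", trust_remote_code=True)

# Input keys are illustrative; check the model card for the exact expected format.
output = pipe(
    {
        "english_sequence": "Is there a promoter in this DNA sequence? <DNA>",
        "dna_sequences": ["ATTCCGATTCCGATTCCG"],  # must be non-empty after this change
    }
)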