chatnt-custom #2
opened by kj03

Files changed:
- chatNT.py (+36 -4)
- text_generation.py (+8 -13)
chatNT.py
CHANGED

@@ -28,6 +28,7 @@ class RotaryEmbeddingConfig:
 class PerceiverResamplerConfig:
     """
     Parameters to initialize an PerceiverResampler model.
+
     Args:
         emb_layer_norm_before: Whether to use layer norm before the first attention
             layer.
@@ -92,7 +93,9 @@ class PerceiverResamplerConfig:
 class GptConfig:
     """
     Parameters to initialize a Gpt model.
+
     NOTE: the pad token is not defined
+
     Args:
         vocab_size: Token vocabulary.
         eos_token_id: used to stop sentence generation
@@ -188,6 +191,7 @@ class GptConfig:
 class NucleotideTransformerConfig:
     """
     Parameters to initialize an NT model.
+
     Args:
         alphabet_size: Token vocabulary.
         pad_token_id: ID of pad token.
@@ -369,6 +373,7 @@ class TorchBioBrainDecoder(nn.Module):
         """
         Initializes the BioBrain decoder, using a GPT model for text generation with
         bio embeddings.
+
         Args:
             gpt_config: Configuration for the GPT model
             seq_token_id: Index of the SEQ token
@@ -385,11 +390,13 @@ class TorchBioBrainDecoder(nn.Module):
     ) -> torch.Tensor:
         """
         Forward pass through the model.
+
         Args:
             english_token_ids: Tensor of English token IDs with shape
                 (batch_size, num_english_tokens).
             projected_bio_embeddings: Optional tensor of bio embeddings with shape
                 (batch_size, num_bio_sequences, ?, embed_dim).
+
         Returns:
             torch.Tensor: The logits from the GPT model,
                 shaped (batch_size, num_english_tokens, vocab_size).
@@ -445,11 +452,13 @@ class TorchBioBrainDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Inserts resampled embeddings in input_embeddings, starting at the SEQ token
+
         Args:
             tokens (torch.Tensor): Shape (batch_size, num_tokens)
             input_embeddings (torch.Tensor): Shape (batch_size, num_tokens, embed_dim)
             resampled_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_sequences, bio_sequence_length, embed_dim)
+
         Returns:
             Tuple[torch.Tensor, torch.Tensor]:
                 - input_embeddings with resampled_embeddings inserted at the SEQ token
@@ -512,9 +521,11 @@ class TorchBioBrainDecoder(nn.Module):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Removes the logits corresponding to the unused embeddings.
+
         Args:
             tokens: Input english tokens.
             logits: Input logits.
+
         Returns:
             Cleaned logits, last values will be equal to 0.
         """
@@ -629,34 +640,39 @@ class TorchMultiOmicsModel(PreTrainedModel):
 
     def forward(
         self,
-        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor
+        multi_omics_tokens_ids: tuple[torch.Tensor, torch.Tensor],
         projection_english_tokens_ids: torch.Tensor,
         projected_bio_embeddings: torch.Tensor = None,
     ) -> dict[str, torch.Tensor]:
         """
+
         Args:
             multi_omics_tokens_ids (Tuple[torch.Tensor, torch.Tensor]):
                 english_tokens_ids: Represents the prompt tokens (english tokens)
                     Shape (batch_size, num_english_tokens)
+
                 bio_tokens_ids: Represents the bio sequences tokens
                     Shape (batch_size, num_bio_sequences, num_bio_tokens)
+
             projection_english_tokens_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
+
             projected_bio_embeddings (projected_bio_embeddings, optional):
                 Shape (batch_size, num_bio_sequencse, ?, embed_dim).
                 Defaults to None.
+
         Returns:
             dict[str, torch.Tensor] containing:
             - logits:
                 Shape (batch_size, num_tokens, vocab_size)
+
             - projected_bio_embeddings:
                 Shape (batch_size, num_bio_sequences, ?, embed_dim)
         """
         english_token_ids, bio_token_ids = multi_omics_tokens_ids
         english_token_ids = english_token_ids.clone()
+        bio_token_ids = bio_token_ids.clone()
         projection_english_tokens_ids = projection_english_tokens_ids.clone()
-        if bio_token_ids is not None:
-            bio_token_ids = bio_token_ids.clone()
         if projected_bio_embeddings is not None:
             projected_bio_embeddings = projected_bio_embeddings.clone()
 
@@ -724,6 +740,7 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _create_sinusoidal_positions(self, device: torch.device) -> torch.Tensor:
         """
         Create the sines and cosines for the RoPE.
+
         Returns:
             Sinusoidal positions of shape (self.max_seq_len, self.dim).
         """
@@ -756,9 +773,11 @@ class TorchRotaryEmbedding(torch.nn.Module):
     def _rotate_every_two(self, x: torch.Tensor) -> torch.Tensor:
         """
         Prepare a tensor to apply the RoPE mechanism.
+
         Args:
             x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                 typically this is the key or query tensor.
+
         Returns:
             The even indices in the last dimension have their sign flipped.
             Tensor of shape (batch_size, seq_len, num_heads, head_dim).
@@ -775,10 +794,12 @@ class TorchRotaryEmbedding(torch.nn.Module):
     ) -> torch.Tensor:
         """
         Applies rotary embeddings to x.
+
         Args:
             x: Tensor of shape (batch_size, seq_len, num_heads, head_dim),
                 typically this is the key or query tensor.
             sincos: Tuple of sine and cosine tensors for position encoding.
+
         Returns:
             RoPE embeddings tensor.
         """
@@ -796,10 +817,12 @@ class TorchRotaryEmbedding(torch.nn.Module):
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Applies rotary embeddings to k and q.
+
         Args:
             k: key tensor of shape (batch_size, seq_len, num_heads, head_dim),
             q: value tensor of shape (batch_size, seq_len, num_heads, head_dim),
             positions: optional positions offset useful when caching,
+
         Returns:
             RoPE embeddings for the keys and values.
         """
@@ -1117,9 +1140,11 @@ def build_causal_attention_mask(
     """
    Builds a batch of causal masks of shape (batch_size, 1, seq_len, seq_len) to feed
    to an attention layer.
+
    Args:
        batch_size: Batch size.
        seq_len: Length of the sequences.
+
    Returns:
        Batch of causal masks.
    """
@@ -1525,11 +1550,13 @@ class TorchNucleotideTransformer(nn.Module):
     ) -> torch.Tensor:
         """
         Computes the embeddings based on the input tokens.
+
         Args:
             tokens: Input tokens out of the tokenizer of shape (batch_size, seq_len).
             attention_mask: Attention mask of shape (batch_size, 1, seq_len, seq_len).
                 If no mask is provided, a mask by default which equals 1 over all non
                 pad tokens and 0 over pad tokens is computed.
+
         Returns:
             Dictionary containing the final embeddings and logits.
         """
@@ -1557,9 +1584,11 @@ def build_padding_attention_mask(
 ) -> torch.Tensor:
     """
     Builds a padding mask from a sequence of tokens by masking <pad> in the attention.
+
     Args:
         tokens: Batch of sequences of shape (batch_size, seq_len).
         pad_token_id: Int corresponding to the <pad> token to mask.
+
     Returns:
         Batch of attention masks, masking out <pad> tokens.
     """
@@ -1586,6 +1615,7 @@ class TorchBioBrainEncoder(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
+
         Returns:
             torch.Tensor:
                 Shape (batch_size, num_bio_tokens, embed_dim)
@@ -1695,6 +1725,7 @@ class TorchMultiModalPerceiverResampler(nn.Module):
     ):
         """
         Initialize a Perceiver Resampler model.
+
         Args:
             config: Dataclass containing model hyperparameters.
             name: Name for module (custom will break weight loading).
@@ -1823,8 +1854,10 @@ class TorchMultiModalPerceiverResamplerProjection(nn.Module):
         Args:
             bio_token_ids (torch.Tensor):
                 Shape (batch_size, num_bio_tokens)
+
             bio_embeddings (torch.Tensor):
                 Shape (batch_size, num_bio_tokens, embed_dim)
+
             english_token_ids (torch.Tensor):
                 Shape (batch_size, num_english_tokens)
         """
@@ -1867,4 +1900,3 @@ def build_perceiver_padding_attention_mask(
     padding_mask = padding_mask[:, None, None, :]
     padding_mask = padding_mask.repeat(1, 1, resampled_length, 1)  # noqa
     return padding_mask
-
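For reviewers: apart from the blank lines restored inside docstrings, the behavioral changes in chatNT.py are the repaired forward signature (the closing "],") and the fact that bio_token_ids is now cloned unconditionally instead of only when it is not None. A minimal sketch of driving the updated forward is shown below; the checkpoint id, the dummy shapes, and reusing the English tokens as projection_english_tokens_ids are illustrative assumptions, not part of this diff.

import torch
from transformers import AutoModel

# Placeholder repo id; trust_remote_code=True is what makes transformers pick up
# the custom chatNT.py from the model repository.
model = AutoModel.from_pretrained("<chatnt-repo-id>", trust_remote_code=True)

# Illustrative shapes only, following the forward() docstring:
# english tokens (batch_size, num_english_tokens),
# bio tokens (batch_size, num_bio_sequences, num_bio_tokens).
english_tokens = torch.randint(0, 100, (1, 32))
bio_tokens = torch.randint(0, 4, (1, 1, 512))

# With this change, the second element of the tuple must be a tensor, since
# forward() now calls bio_token_ids.clone() unconditionally.
outputs = model(
    multi_omics_tokens_ids=(english_tokens, bio_tokens),
    projection_english_tokens_ids=english_tokens,  # illustrative reuse
    projected_bio_embeddings=None,
)
print(outputs["logits"].shape)  # (batch_size, num_tokens, vocab_size)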
text_generation.py
CHANGED

@@ -55,24 +55,19 @@ class TextGenerationPipeline(Pipeline):
             truncation=True,
             max_length=english_tokens_max_length,
         ).input_ids
-
-
-
-
-
-
-
-            max_length=bio_tokens_max_length,
-            truncation=True,
-        ).input_ids.unsqueeze(0)
+        bio_tokens = self.bio_tokenizer(
+            dna_sequences,
+            return_tensors="pt",
+            padding="max_length",
+            max_length=bio_tokens_max_length,
+            truncation=True,
+        ).input_ids.unsqueeze(0)
 
         return {"english_tokens": english_tokens, "bio_tokens": bio_tokens}
 
     def _forward(self, model_inputs: dict, max_num_tokens_to_decode: int = 50) -> dict:
         english_tokens = model_inputs["english_tokens"].clone()
-        bio_tokens = model_inputs["bio_tokens"]
-        if bio_tokens is not None:
-            bio_tokens = bio_tokens.clone()
+        bio_tokens = model_inputs["bio_tokens"].clone()
         projected_bio_embeddings = None
 
         actual_num_steps = 0