import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from typing import Optional, Tuple, Union
from ukraine.research.transformer.transformer import Transformer
from ukraine.research.transformer.layers import SiLUFeedForward
from ukraine.research.transformer.masking import generate_square_subsequent_mask
from .configuration_lime import LIMEConfig


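# Factory callables handed to the custom Transformer so that it can build SiLU
# feed-forward blocks and RMSNorm layers from the LIME configuration.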
def make_ff(config: LIMEConfig):
    return SiLUFeedForward(
        d_model=config.d_model,
        dff=config.dff,
        multiple_of=config.multiple_of
    )


def make_norm(config: LIMEConfig):
    return nn.RMSNorm(config.d_model)


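# Hugging Face-compatible causal LM wrapper around the project's custom
# Transformer, exposing it through the PreTrainedModel / GenerationMixin APIs
# (weight tying, save/load, and .generate()).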
class LIMEForCausalLM(PreTrainedModel, GenerationMixin):
    config_class = LIMEConfig
    base_model_prefix = "lime"
    _tied_weights_keys = ["transformer.output_fc.weight"]

    def __init__(self, config: LIMEConfig):
        super().__init__(config)
        self.config = config

        self.transformer = Transformer(
            num_encoder_layers=config.num_encoder_layers,
            num_decoder_layers=config.num_decoder_layers,
            d_model=config.d_model,
            num_heads=config.num_heads,
            input_vocab_size=config.vocab_size,
            target_vocab_size=config.vocab_size,
            dropout_rate=config.dropout_rate,
            ff_factory=lambda: make_ff(config),
            norm_factory=lambda: make_norm(config),
            pad_token_id=config.pad_token_id,
            use_encoder=config.use_encoder,
            use_flash=config.use_flash
        )

        self.post_init()

    # Embedding accessors required by the transformers library (embedding resizing, weight tying).
    def get_input_embeddings(self):
        return self.transformer.decoder.embedding

    def set_input_embeddings(self, value):
        self.transformer.decoder.embedding = value

    def get_output_embeddings(self):
        return self.transformer.output_fc

    def set_output_embeddings(self, new_embeddings):
        self.transformer.output_fc = new_embeddings

    def _tie_weights(self):
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(
                self.transformer.output_fc,
                self.get_input_embeddings()
            )

    def forward(
            self,
            input_ids: torch.LongTensor,
            attention_mask: Optional[torch.Tensor] = None,
            token_type_ids: Optional[torch.LongTensor] = None,
            labels: Optional[torch.LongTensor] = None,
            return_dict: Optional[bool] = None,
            **kwargs
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        # attention_mask and token_type_ids are accepted for API compatibility but are
        # not used; padding is derived from config.pad_token_id below.
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Causal (look-ahead) mask over the target sequence.
        tgt_mask = generate_square_subsequent_mask(seq_len, device)

        # A key-padding mask is only needed when computing the training loss.
        if labels is not None:
            tgt_key_padding_mask = input_ids.eq(self.config.pad_token_id)
        # During inference it is not required.
        else:
            tgt_key_padding_mask = None

        logits, _ = self.transformer(
            src=input_ids,
            tgt_mask=tgt_mask,
            tgt_key_padding_mask=tgt_key_padding_mask
        )

        loss = None
        if labels is not None:
            # Shift so that the token at position t predicts the token at position t + 1.
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = labels[:, 1:].contiguous()
            # -100 is the ignore index applied to masked-out positions during SFT training.
            criterion = nn.CrossEntropyLoss(ignore_index=-100)
            loss = criterion(
                shift_logits.reshape(-1, self.config.vocab_size),
                shift_labels.reshape(-1)
            )

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None
        )
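

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only). It assumes LIMEConfig accepts the
# keyword arguments read by the model above; the values below are placeholders,
# not the defaults of any released checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LIMEConfig(
        vocab_size=32000,        # assumed vocabulary size
        d_model=256,             # assumed hidden size
        num_heads=4,
        num_encoder_layers=0,
        num_decoder_layers=4,
        dff=1024,
        multiple_of=256,
        dropout_rate=0.0,
        pad_token_id=0,
        use_encoder=False,       # decoder-only causal LM
        use_flash=False,
        tie_word_embeddings=True,
    )
    model = LIMEForCausalLM(config).eval()

    # A forward pass with labels returns the shifted cross-entropy loss.
    dummy_ids = torch.randint(0, config.vocab_size, (1, 16))
    out = model(input_ids=dummy_ids, labels=dummy_ids)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))

    # One greedy decoding step; with a recent transformers release,
    # model.generate(...) should also work since the class mixes in GenerationMixin.
    next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    print("next token id:", next_token.item())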