Upload modelling.py
modelling.py  CHANGED  +5 -35

@@ -5,8 +5,7 @@ from torch import nn
 from torch.nn import CrossEntropyLoss
 from transformers import ModernBertModel, ModernBertPreTrainedModel, ModernBertConfig
 from transformers.modeling_outputs import QuestionAnsweringModelOutput
-from transformers.models.modernbert.modeling_modernbert import
-    ModernBertPredictionHead
+from transformers.models.modernbert.modeling_modernbert import ModernBertPredictionHead


 class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
@@ -26,10 +25,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()

-    @torch.compile(dynamic=True)
-    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
-        return self.head(output)
-
     def forward(
         self,
         input_ids: Optional[torch.Tensor],
@@ -46,6 +41,7 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -60,20 +56,6 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
         self._maybe_set_compile()

-        # Get sequence length and batch size if not provided
-        # if batch_size is None or seq_len is None:
-        #     batch_size, seq_len = input_ids.shape[:2]
-
-        # # Handle Flash Attention 2 unpadding
-        # if self.config._attn_implementation == "flash_attention_2":
-        #     if indices is None and cu_seqlens is None and max_seqlen is None:
-        #         if attention_mask is None:
-        #             attention_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.bool)
-        #         with torch.no_grad():
-        #             input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = _unpad_modernbert_input(
-        #                 inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids
-        #             )
-
         outputs = self.model(
             input_ids,
             attention_mask=attention_mask,
@@ -90,24 +72,12 @@ class ModernBertForQuestionAnswering(ModernBertPreTrainedModel):
         )

         sequence_output = outputs[0]
-        sequence_output = (
-            self.drop(self.compiled_head(sequence_output))
-            if self.config.reference_compile
-            else self.drop(self.head(sequence_output))
-        )
-        # sequence_output = self.drop(self.head(sequence_output))
+        sequence_output = self.drop(self.head(sequence_output))

         logits = self.qa_outputs(sequence_output)
         start_logits, end_logits = logits.split(1, dim=-1)
-        start_logits = start_logits.squeeze(-1)
-        end_logits = end_logits.squeeze(-1)
-
-        # # Handle Flash Attention 2 padding
-        # if self.config._attn_implementation == "flash_attention_2":
-        #     start_logits = _pad_modernbert_output(inputs=start_logits, indices=indices, batch=batch_size,
-        #                                           seqlen=seq_len)
-        #     end_logits = _pad_modernbert_output(inputs=end_logits, indices=indices, batch=batch_size,
-        #                                         seqlen=seq_len)
+        start_logits = start_logits.squeeze(-1).contiguous()
+        end_logits = end_logits.squeeze(-1).contiguous()

         total_loss = None
         if start_positions is not None and end_positions is not None: