Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| from langdetect import detect | |
| import torch | |
| import torch.nn as nn | |
| from transformers import ( | |
| DistilBertTokenizer, DistilBertModel, | |
| AutoTokenizer, AutoModel | |
| ) | |
# ==== Model Classes ====
class ToxicBERT(nn.Module):
    """DistilBERT encoder followed by dropout and a six-logit linear head.

    The attribute names (``bert``, ``dropout``, ``classifier``) double as
    state-dict key prefixes, so they must stay stable for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        encoder_width = self.bert.config.hidden_size
        self.classifier = nn.Linear(encoder_width, 6)

    def forward(self, input_ids, attention_mask):
        """Return one row of six raw logits per input sequence."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Position 0 along the sequence axis is used as the sequence summary.
        first_token_state = encoded.last_hidden_state[:, 0]
        regularized = self.dropout(first_token_state)
        return self.classifier(regularized)
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling and a two-logit head.

    The encoder's per-token states are pooled along the sequence axis (mean
    and max, concatenated), squeezed back to ``hidden_size`` by a small
    bottleneck, and projected to two logits.

    The attribute names (``bert``, ``bottleneck``, ``classifier``) double as
    state-dict key prefixes, so they must stay stable for checkpoints to load.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        hidden_size = self.bert.config.hidden_size
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(hidden_size, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate the mean and the max of *hidden* taken over dim 1.

        Defined as a method rather than a lambda stored on the instance: a
        lambda in the module's ``__dict__`` makes the whole module unpicklable
        (e.g. ``torch.save(model)`` fails), while the call site
        ``self.pool(hidden)`` is unchanged.

        Args:
            hidden: Tensor whose dim 1 is reduced; the result doubles the size
                of the last dimension of a 3-D input.

        Returns:
            ``torch.cat([mean, max], dim=1)`` of the two reductions.
        """
        # NOTE(review): every position along dim 1 contributes, padding
        # included, because no attention mask is applied here. Left as-is on
        # purpose: the published checkpoint was presumably trained with this
        # exact pooling, so masking it now would change the model's inputs.
        mean_pooled = hidden.mean(dim=1)
        max_pooled = hidden.max(dim=1).values
        return torch.cat([mean_pooled, max_pooled], dim=1)

    def forward(self, input_ids, attention_mask):
        """Return one row of two raw logits per input sequence."""
        hidden = self.bert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        pooled = self.pool(hidden)
        x = self.bottleneck(pooled)
        return self.classifier(x)
# ==== Load Tokenizers ====
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# ==== Load Models from Hugging Face Hub ====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _load_finetuned(model, url):
    """Load the state dict at *url* into *model*, switch it to eval mode, and
    move it to ``device``; returns the same model object."""
    # NOTE(review): the checkpoint is a pickled file fetched over the network.
    # Consider passing ``weights_only=True`` once the minimum supported
    # PyTorch version is known to accept it.
    state_dict = torch.hub.load_state_dict_from_url(url, map_location=device)
    model.load_state_dict(state_dict)
    return model.eval().to(device)


eng_url = "https://huggingface.co/koyu008/English_Toxic_Classifier/resolve/main/bert_toxic_classifier.pt"
english_model = _load_finetuned(ToxicBERT(), eng_url)

hin_url = "https://huggingface.co/koyu008/HInglish_comment_classifier/resolve/main/best_hinglish_model.pt"
hinglish_model = _load_finetuned(HinglishToxicClassifier(), hin_url)

# ==== FastAPI setup ====
app = FastAPI()
class InputText(BaseModel):
    """Request body for ``predict``: the raw text to classify."""

    text: str
# NOTE(review): no route decorator (e.g. ``@app.post(...)``) is visible on this
# function in this view. Confirm it is registered with ``app`` somewhere,
# otherwise the endpoint is never served.
def predict(input: InputText):
    """Classify the toxicity of a piece of text.

    The text is routed by detected language: English goes to the six-label
    ``english_model``; anything else, including text whose language cannot be
    detected, goes to the binary ``hinglish_model``.

    Args:
        input: Request body carrying the raw ``text`` to classify.

    Returns:
        A dict with ``"language"`` (``"english"`` or ``"hinglish"``) and
        ``"prediction"``, a mapping of label name to probability rounded to
        four decimal places.

    Raises:
        HTTPException: 400 if the text is empty or whitespace-only.
    """
    text = input.text.strip()
    if not text:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    # Language detection is best-effort: any detector failure falls back to
    # the non-English path. ``Exception`` (not a bare ``except``) so that
    # KeyboardInterrupt / SystemExit still propagate.
    try:
        lang = detect(text)
    except Exception:
        lang = "und"

    is_english = lang == "en"
    if is_english:
        model = english_model
        tokenizer = english_tokenizer
        labels = ["toxic", "severe toxic", "obscene", "threat", "insult", "identity hate"]
    else:
        model = hinglish_model
        tokenizer = hinglish_tokenizer
        labels = ["not toxic", "toxic"]

    # Tokenization
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)

    with torch.no_grad():
        # Pass only the two tensors ``forward`` accepts, so an extra tokenizer
        # output (e.g. ``token_type_ids``) can never break the call.
        logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )

    if is_english:
        # The six English labels are not mutually exclusive and there is no
        # "clean" class, so each logit is scored independently. A softmax
        # would force the six scores to sum to 1 even for harmless text.
        # NOTE(review): assumes the checkpoint was trained as a multi-label
        # (per-label sigmoid) classifier - confirm against the training code.
        scores = torch.sigmoid(logits)
    else:
        # "not toxic" vs "toxic" are exclusive, so normalise across the pair.
        scores = torch.softmax(logits, dim=1)

    # Exactly one text was tokenized, so row 0 is the only row.
    probs = scores[0].tolist()

    return {
        "language": "english" if is_english else "hinglish",
        "prediction": {label: round(prob, 4) for label, prob in zip(labels, probs)},
    }