Spaces:
Sleeping
Sleeping
import os

import torch
import torch.nn as nn
from fastapi import FastAPI, Request
from huggingface_hub import snapshot_download
from langdetect import detect
from langdetect import DetectorFactory
from langdetect.lang_detect_exception import LangDetectException
from pydantic import BaseModel
from transformers import DistilBertTokenizer, DistilBertModel, AutoModel, AutoTokenizer
# Device: run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Download model repos from HF Hub. The returned local directories are joined
# with the checkpoint file names when the weights are loaded further down.
english_repo = snapshot_download(repo_id="koyu008/English_Toxic_Classifier")
hinglish_repo = snapshot_download(repo_id="koyu008/HInglish_comment_classifier")

# Tokenizers: each one matches the base encoder its classifier is built on.
english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
| # English Model | |
class ToxicBERT(nn.Module):
    """DistilBERT encoder followed by dropout and a linear head with six outputs."""

    def __init__(self):
        super().__init__()
        # Attribute names double as state-dict key prefixes, so they must stay
        # exactly `bert`, `dropout` and `classifier` for the checkpoint to load.
        self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.dropout = nn.Dropout(0.3)
        encoder_width = self.bert.config.hidden_size
        self.classifier = nn.Linear(encoder_width, 6)

    def forward(self, input_ids, attention_mask):
        """Return the head's output for the hidden state at sequence position 0."""
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        first_position = encoded.last_hidden_state[:, 0]
        return self.classifier(self.dropout(first_position))
| # Hinglish Model | |
class HinglishToxicClassifier(nn.Module):
    """XLM-RoBERTa encoder with mean+max pooling, a bottleneck layer and a 2-output head."""

    def __init__(self):
        super().__init__()
        # `bert`, `bottleneck` and `classifier` are state-dict key prefixes;
        # renaming them would break loading of the saved checkpoint.
        self.bert = AutoModel.from_pretrained("xlm-roberta-base")
        width = self.bert.config.hidden_size
        self.bottleneck = nn.Sequential(
            nn.Linear(2 * width, width),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.classifier = nn.Linear(width, 2)

    @staticmethod
    def pool(hidden):
        """Concatenate, along dim 1, the mean and the max of *hidden* taken over dim 1."""
        averaged = hidden.mean(dim=1)
        peak = hidden.max(dim=1).values
        return torch.cat([averaged, peak], dim=1)

    def forward(self, input_ids, attention_mask):
        """Encode the inputs, pool the hidden states, and return the head's output.

        The attention mask is passed to the encoder only; pooling runs over
        every position of the encoder output.
        """
        token_states = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
        ).last_hidden_state
        squeezed = self.bottleneck(self.pool(token_states))
        return self.classifier(squeezed)
# Instantiate and load models
def _load_for_inference(model, repo_dir, weights_file):
    """Move *model* to `device`, load its weights from *repo_dir*, and set eval mode."""
    model = model.to(device)
    # NOTE(review): torch.load unpickles the checkpoint file, and these files
    # are downloaded from the Hub -- only load repositories you trust.
    model.load_state_dict(torch.load(os.path.join(repo_dir, weights_file), map_location=device))
    model.eval()
    return model


english_model = _load_for_inference(ToxicBERT(), english_repo, "bert_toxic_classifier.pt")
hinglish_model = _load_for_inference(
    HinglishToxicClassifier(), hinglish_repo, "best_hinglish_model.pt"
)
# Labels, in the order they are zipped with each model's output values.
english_labels = [
    'toxic',
    'severe toxic',
    'obscene',
    'threat',
    'insult',
    'identity hate',
]
hinglish_labels = [
    'not toxic',
    'toxic',
]
# FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# CORS: any origin may call the API. Credentials are disabled because a
# wildcard origin must not be combined with allow_credentials=True (FastAPI's
# CORS documentation requires explicit origins, methods and headers for
# credentialed requests), and the handlers in this file read no cookies or
# auth headers. To support credentialed requests, list the frontend origins
# explicitly and then re-enable allow_credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Or restrict to your frontend domain
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
class TextIn(BaseModel):
    """Request body for a prediction: the single piece of text to classify."""

    # Raw text; its detected language decides which classifier scores it.
    text: str
# langdetect is non-deterministic unless seeded: without a fixed seed the same
# short or ambiguous text can be detected as different languages, and so be
# routed to a different model, from one request to the next.
DetectorFactory.seed = 0


def _detect_language(text: str) -> str:
    """Return langdetect's language code for *text*, or "unknown" when detection fails."""
    try:
        return detect(text)
    except LangDetectException:
        # Best-effort: undetectable text falls through to the non-English path.
        return "unknown"


def _run_model(text, tokenizer, model):
    """Tokenize *text* as a single input and return *model*'s raw output without gradients."""
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        return model(**inputs)


# NOTE(review): SOURCE shows no route decorator on this handler, so it was
# never reachable through `app`. The "/predict" path is an assumption --
# confirm it against the client before deploying.
@app.post("/predict")
def predict(data: TextIn):
    """Classify the toxicity of ``data.text`` with the model matching its language.

    Text detected as English ("en") is scored by the English model: a sigmoid
    is applied to each output, giving one independent score per label in
    ``english_labels``. Anything else -- including text whose language could
    not be detected -- is scored by the Hinglish model with a softmax over
    ``hinglish_labels`` and is reported as "Hinglish".

    Returns:
        A dict with "language" ("English" or "Hinglish") and "predictions",
        a mapping from label name to score.
    """
    text = data.text

    if _detect_language(text) == "en":
        outputs = _run_model(text, english_tokenizer, english_model)
        # One input was tokenized, so dim 0 (the batch axis) has size 1.
        probs = torch.sigmoid(outputs).squeeze(0).cpu().tolist()
        return {"language": "English", "predictions": dict(zip(english_labels, probs))}

    outputs = _run_model(text, hinglish_tokenizer, hinglish_model)
    probs = torch.softmax(outputs, dim=1).squeeze(0).cpu().tolist()
    return {"language": "Hinglish", "predictions": dict(zip(hinglish_labels, probs))}
# NOTE(review): SOURCE shows no route decorator on this handler, so it was
# never reachable through `app`. The "/" path is assumed from the function
# name -- confirm before deploying.
@app.get("/")
def root():
    """Return a fixed status message showing that the API process is up."""
    return {"message": "API is running"}