koyu008 commited on
Commit
2b470ab
·
verified ·
1 Parent(s): 8b0dba6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -103
app.py CHANGED
@@ -1,103 +1,92 @@
1
- from fastapi import FastAPI, HTTPException
2
- from pydantic import BaseModel
3
- from langdetect import detect
4
- import torch
5
- import torch.nn as nn
6
- from transformers import (
7
- DistilBertTokenizer, DistilBertModel,
8
- AutoTokenizer, AutoModel
9
- )
10
-
11
- # ==== Model Classes ====
12
-
13
- class ToxicBERT(nn.Module):
14
- def __init__(self):
15
- super().__init__()
16
- self.bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
17
- self.dropout = nn.Dropout(0.3)
18
- self.classifier = nn.Linear(self.bert.config.hidden_size, 6)
19
-
20
- def forward(self, input_ids, attention_mask):
21
- output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
22
- return self.classifier(self.dropout(output))
23
-
24
- class HinglishToxicClassifier(nn.Module):
25
- def __init__(self):
26
- super().__init__()
27
- self.bert = AutoModel.from_pretrained("xlm-roberta-base")
28
- hidden_size = self.bert.config.hidden_size
29
-
30
- self.pool = lambda hidden: torch.cat([
31
- hidden.mean(dim=1),
32
- hidden.max(dim=1).values
33
- ], dim=1)
34
-
35
- self.bottleneck = nn.Sequential(
36
- nn.Linear(2 * hidden_size, hidden_size),
37
- nn.ReLU(),
38
- nn.Dropout(0.2)
39
- )
40
- self.classifier = nn.Linear(hidden_size, 2)
41
-
42
- def forward(self, input_ids, attention_mask):
43
- hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
44
- pooled = self.pool(hidden)
45
- x = self.bottleneck(pooled)
46
- return self.classifier(x)
47
-
48
- # ==== Load Tokenizers ====
49
- english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
50
- hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
51
-
52
- # ==== Load Models from Hugging Face Hub ====
53
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
-
55
- english_model = ToxicBERT()
56
- eng_url = "https://huggingface.co/koyu008/English_Toxic_Classifier/resolve/main/bert_toxic_classifier.pt"
57
- english_model.load_state_dict(torch.hub.load_state_dict_from_url(eng_url, map_location=device))
58
- english_model.eval().to(device)
59
-
60
- hinglish_model = HinglishToxicClassifier()
61
- hin_url = "https://huggingface.co/koyu008/HInglish_comment_classifier/resolve/main/best_hinglish_model.pt"
62
- hinglish_model.load_state_dict(torch.hub.load_state_dict_from_url(hin_url, map_location=device))
63
- hinglish_model.eval().to(device)
64
-
65
- # ==== FastAPI setup ====
66
-
67
- app = FastAPI()
68
-
69
- class InputText(BaseModel):
70
- text: str
71
-
72
- @app.post("/predict")
73
- def predict(input: InputText):
74
- text = input.text.strip()
75
- if not text:
76
- raise HTTPException(status_code=400, detail="Input text cannot be empty")
77
-
78
- # Language detection
79
- try:
80
- lang = detect(text)
81
- except:
82
- lang = "und"
83
-
84
- if lang == "en":
85
- model = english_model
86
- tokenizer = english_tokenizer
87
- labels = ["toxic", "severe toxic", "obscene", "threat", "insult", "identity hate"]
88
- else:
89
- model = hinglish_model
90
- tokenizer = hinglish_tokenizer
91
- labels = ["not toxic", "toxic"]
92
-
93
- # Tokenization
94
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
95
- with torch.no_grad():
96
- outputs = model(**inputs)
97
- probs = torch.softmax(outputs, dim=1).squeeze().tolist()
98
-
99
- response = {
100
- "language": "english" if lang == "en" else "hinglish",
101
- "prediction": {label: float(round(prob, 4)) for label, prob in zip(labels, probs)}
102
- }
103
- return response
 
1
+ from fastapi import FastAPI, Request
2
+ from pydantic import BaseModel
3
+ from langdetect import detect
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import DistilBertModel, AutoModel, AutoTokenizer, DistilBertTokenizer
7
+ from huggingface_hub import snapshot_download
8
+ import os
9
+
10
+ app = FastAPI()
11
+
12
+ # Use local cache folder for downloaded models
13
+ os.environ["TRANSFORMERS_CACHE"] = "/app/.hf_cache"
14
+ os.makedirs("/app/.hf_cache", exist_ok=True)
15
+
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+
18
+ # -------------------------------
19
+ # Model Classes
20
+ # -------------------------------
21
+
22
+ class ToxicBERT(nn.Module):
23
+ def __init__(self):
24
+ super().__init__()
25
+ self.bert = DistilBertModel.from_pretrained(snapshot_download("koyu008/English_Toxic_Classifier"))
26
+ self.dropout = nn.Dropout(0.3)
27
+ self.classifier = nn.Linear(self.bert.config.hidden_size, 6)
28
+
29
+ def forward(self, input_ids, attention_mask):
30
+ output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0]
31
+ return self.classifier(self.dropout(output))
32
+
33
+ class HinglishToxicClassifier(nn.Module):
34
+ def __init__(self):
35
+ super().__init__()
36
+ self.bert = AutoModel.from_pretrained(snapshot_download("koyu008/Hinglish_comment_classifier"))
37
+ hidden_size = self.bert.config.hidden_size
38
+ self.pool = lambda hidden: torch.cat([
39
+ hidden.mean(dim=1),
40
+ hidden.max(dim=1).values
41
+ ], dim=1)
42
+ self.bottleneck = nn.Sequential(
43
+ nn.Linear(2 * hidden_size, hidden_size),
44
+ nn.ReLU(),
45
+ nn.Dropout(0.2)
46
+ )
47
+ self.classifier = nn.Linear(hidden_size, 2)
48
+
49
+ def forward(self, input_ids, attention_mask):
50
+ hidden = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
51
+ pooled = self.pool(hidden)
52
+ x = self.bottleneck(pooled)
53
+ return self.classifier(x)
54
+
55
+ # -------------------------------
56
+ # Load Models and Tokenizers
57
+ # -------------------------------
58
+
59
+ english_model = ToxicBERT().to(device)
60
+ english_model.load_state_dict(torch.load("bert_toxic_classifier.pt", map_location=device))
61
+ english_model.eval()
62
+ english_tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
63
+
64
+ hinglish_model = HinglishToxicClassifier().to(device)
65
+ hinglish_model.load_state_dict(torch.load("best_hinglish_model.pt", map_location=device))
66
+ hinglish_model.eval()
67
+ hinglish_tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
68
+
69
+ # -------------------------------
70
+ # Request & Inference
71
+ # -------------------------------
72
+
73
+ class InputText(BaseModel):
74
+ text: str
75
+
76
+ @app.post("/predict")
77
+ async def predict(input: InputText):
78
+ text = input.text
79
+ lang = detect(text)
80
+
81
+ if lang == "en":
82
+ inputs = english_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
83
+ with torch.no_grad():
84
+ logits = english_model(**inputs)
85
+ probs = torch.softmax(logits, dim=1).cpu().numpy().tolist()[0]
86
+ return {"language": "english", "classes": ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"], "probabilities": probs}
87
+ else:
88
+ inputs = hinglish_tokenizer(text, return_tensors="pt", truncation=True, padding=True).to(device)
89
+ with torch.no_grad():
90
+ logits = hinglish_model(**inputs)
91
+ probs = torch.softmax(logits, dim=1).cpu().numpy().tolist()[0]
92
+ return {"language": "hinglish", "classes": ["toxic", "non-toxic"], "probabilities": probs}