NghiTran1009 commited on
Commit
65ad6d0
·
verified ·
1 Parent(s): f4ec716

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -10
app.py CHANGED
@@ -1,29 +1,38 @@
1
  import gradio as gr
2
  import uuid
3
- import shutil
4
  import os
 
 
 
5
 
6
  CO_SPY_DIR = "./CO-SPY"
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def gradio_scan_image(image):
9
- # Lưu ảnh upload vào CO-SPY với tên ngẫu nhiên
10
  dst_path = os.path.join(CO_SPY_DIR, f"input_{uuid.uuid4().hex}.png")
11
  image.save(dst_path)
12
-
13
- # Import các thư viện cần thiết cho scan
14
- from evaluate import scan, args # giả sử args đã được khởi tạo trong evaluate.py
15
-
16
- # Chạy trực tiếp scan
17
  probability, label = scan(args, dst_path)
18
  return f"Prediction: {probability:.3f}, Label: {label}"
19
 
20
- # Khởi tạo giao diện Gradio
21
  iface = gr.Interface(
22
  fn=gradio_scan_image,
23
  inputs=gr.Image(type="pil"),
24
  outputs="text",
25
- title="CO-SPY Scan Synthetic Image",
26
- description="Upload ảnh và CO-SPY sẽ trả về xác suất và nhãn."
27
  )
28
 
29
  iface.launch()
 
1
  import gradio as gr
2
  import uuid
 
3
  import os
4
+ from PIL import Image
5
+ import torch
6
+ from evaluate import scan, Detector, seed_torch # import hàm scan đã sửa
7
 
8
  CO_SPY_DIR = "./CO-SPY"
9
 
10
+ # --- Khởi tạo args và seed trước ---
11
+ class Args:
12
+ gpu = 0
13
+ device = f"cuda:{gpu}" if torch.cuda.is_available() else "cpu"
14
+ semantic_weights_path = "pretrained/semantic_weights.pth"
15
+ artifact_weights_path = "pretrained/artifact_weights.pth"
16
+ classifier_weights_path = "pretrained/classifier_weights.pth"
17
+ batch_size = 32
18
+ seed = 1024
19
+
20
+ args = Args()
21
+ seed_torch(args.seed)
22
+
23
+ # --- Hàm Gradio ---
24
  def gradio_scan_image(image):
 
25
  dst_path = os.path.join(CO_SPY_DIR, f"input_{uuid.uuid4().hex}.png")
26
  image.save(dst_path)
27
+
 
 
 
 
28
  probability, label = scan(args, dst_path)
29
  return f"Prediction: {probability:.3f}, Label: {label}"
30
 
 
31
  iface = gr.Interface(
32
  fn=gradio_scan_image,
33
  inputs=gr.Image(type="pil"),
34
  outputs="text",
35
+ title="CO-SPY Scan Synthetic Image"
 
36
  )
37
 
38
  iface.launch()