neovalle committed on
Commit 9bcd9ad · verified · 1 Parent(s): db17601

Update app.py

Files changed (1)
  1. app.py +104 -68
app.py CHANGED
@@ -1,20 +1,35 @@
-import gradio as gr
-import torch
-from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
-from threading import Thread
+# app.py
+import io
 from datetime import datetime
+
+import gradio as gr
 import pandas as pd
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
-# ---------- Config ----------
+# ----------------------------
+# Config
+# ----------------------------
 
-# Small, free chat models that run on CPU in a basic Space (pick one if you like)
+# Small, free, instruction-tuned models that run on CPU in a Basic Space.
 DEFAULT_MODELS = [
+    "google/gemma-2-2b-it",
     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
     "Qwen/Qwen2.5-1.5B-Instruct",
 ]
 
-# Cache for loaded models to avoid reloading on each call
-_MODEL_CACHE = {}
+_MODEL_CACHE = {}  # (tokenizer, model) cache
+
+
+# ----------------------------
+# Utilities
+# ----------------------------
+
+def df_to_csv_bytes(df: pd.DataFrame) -> bytes:
+    buf = io.StringIO()
+    df.to_csv(buf, index=False)
+    return buf.getvalue().encode("utf-8")
+
 
 def _load_model(model_id: str):
     """Load tokenizer and model (cached)."""
@@ -22,26 +37,36 @@ def _load_model(model_id: str):
         return _MODEL_CACHE[model_id]
 
     tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
-    # bfloat16 works on many CPUs and GPUs; fall back to float32 if needed
-    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
 
+    # Ensure we have a pad token to avoid warnings in generate
+    if tok.pad_token is None:
+        # Prefer eos_token, else add a pad token
+        if tok.eos_token is not None:
+            tok.pad_token = tok.eos_token
+        else:
+            tok.add_special_tokens({"pad_token": "<|pad|>"})
+
+    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
     model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
+    # If we added a pad token, resize embeddings
+    if model.get_input_embeddings().num_embeddings != len(tok):
+        model.resize_token_embeddings(len(tok))
 
     _MODEL_CACHE[model_id] = (tok, model)
     return tok, model
 
+
 def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
     """
-    Use the model's chat template if available; otherwise
-    create a simple system+user concatenation.
+    Prefer the model's chat template. Fallback to a light instruction format.
     """
-    sys = system_prompt.strip() if system_prompt else ""
-    usr = user_prompt.strip()
+    sys = (system_prompt or "").strip()
+    usr = (user_prompt or "").strip()
 
     if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
         messages = []
@@ -53,12 +78,11 @@ def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
             tokenize=False,
             add_generation_prompt=True,
         )
-    # Fallback: a lightweight instruction format
-    prompt = ""
-    if sys:
-        prompt += f"<<SYS>>\n{sys}\n<</SYS>>\n\n"
-    prompt += f"<<USER>>\n{usr}\n<</USER>>\n<<ASSISTANT>>\n"
-    return prompt
+
+    # Fallback format
+    prefix = f"<<SYS>>\n{sys}\n<</SYS>>\n\n" if sys else ""
+    return f"{prefix}<<USER>>\n{usr}\n<</USER>>\n<<ASSISTANT>>\n"
+
 
 def generate_batch(
     model_id: str,
@@ -69,72 +93,72 @@ def generate_batch(
     top_p: float,
     top_k: int,
     repetition_penalty: float,
-):
-    """Generate for multiple user prompts (one per line)."""
+) -> pd.DataFrame:
+    """Generate responses for multiple user prompts (one per line)."""
     tok, model = _load_model(model_id)
     device = model.device
 
-    # Split lines, drop empties
+    # Split lines, discard empties
     prompts = [p.strip() for p in prompts_multiline.splitlines() if p.strip()]
     if not prompts:
         return pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])
 
-    # Prepare inputs
+    # Build formatted prompts per model
     formatted = [_format_prompt(tok, system_prompt, p) for p in prompts]
-    inputs = tok(
+
+    enc = tok(
        formatted,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)
 
+    # True prompt lengths per row (use attention mask sum to ignore padding)
+    prompt_lens = enc["attention_mask"].sum(dim=1)
+
     with torch.no_grad():
-        outputs = model.generate(
-            **inputs,
-            max_new_tokens=max_new_tokens,
+        gen = model.generate(
+            **enc,
+            max_new_tokens=int(max_new_tokens),
             do_sample=(temperature > 0.0),
-            temperature=temperature if temperature > 0 else None,
-            top_p=top_p,
-            top_k=top_k if top_k > 0 else None,
-            repetition_penalty=repetition_penalty,
+            temperature=float(temperature) if temperature > 0 else None,
+            top_p=float(top_p),
+            top_k=int(top_k) if int(top_k) > 0 else None,
+            repetition_penalty=float(repetition_penalty),
             eos_token_id=tok.eos_token_id,
-            pad_token_id=tok.eos_token_id,
+            pad_token_id=tok.pad_token_id,
         )
 
-    # Slice off the prompt tokens to get only the generated text
-    gen_texts = []
-    for i in range(outputs.size(0)):
-        prompt_len = inputs["input_ids"][i].size(0)
-        # Some tokenizers need special handling; safest: decode full and strip prompt
-        full = tok.decode(outputs[i], skip_special_tokens=True)
-        prompt_only = tok.decode(inputs["input_ids"][i], skip_special_tokens=True)
-        # Remove the first occurrence of the prompt text
-        resp = full[len(prompt_only):].strip()
-        gen_texts.append(resp)
+    # Slice generated tokens per row using actual prompt length
+    responses = []
+    tokens_out = []
+    for i in range(gen.size(0)):
+        start = int(prompt_lens[i].item())
+        gen_ids = gen[i, start:]
+        text = tok.decode(gen_ids, skip_special_tokens=True).strip()
+        responses.append(text)
+        tokens_out.append(len(gen_ids))
 
     df = pd.DataFrame(
         {
             "user_prompt": prompts,
-            "response": gen_texts,
-            "tokens_out": [len(tok.encode(t)) for t in gen_texts],
+            "response": responses,
+            "tokens_out": tokens_out,
         }
     )
     return df
 
-def to_csv(df: pd.DataFrame):
-    ts = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
-    path = f"/tmp/batch_{ts}.csv"
-    df.to_csv(path, index=False)
-    return path
 
-# ---------- UI ----------
+# ----------------------------
+# Gradio UI
+# ----------------------------
 
 with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
     gr.Markdown(
         """
-        # 🧪 Multi-Prompt Chat for HF Space
-        Pick a small free model, set a **system prompt**, and enter **multiple user prompts** (one per line).
-        Click **Generate** to get batched responses as a table (downloadable as CSV).
+        # 🧪 Multi-Prompt Chat for HF Space
+        Pick a small free model, set a **system prompt**, and enter **multiple user prompts** (one per line).
+        Click **Generate** to get batched responses, then **Download CSV** for offline use.
         """
    )
 
@@ -153,7 +177,7 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
            )
            prompts_multiline = gr.Textbox(
                label="User prompts (one per line)",
-                placeholder="Write one query per line.\nExample:\nExplain transformers in simple terms\nGive 3 eco-friendly tips for students\nSummarise the benefits of multilingual models",
+                placeholder="One query per line.\nExample:\nExplain transformers in simple terms\nGive 3 eco-friendly tips for students\nSummarise the benefits of multilingual models",
                lines=10,
            )
 
@@ -161,13 +185,15 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
            max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="max_new_tokens")
            temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
-            top_k = gr.Slider(0, 200, value=40, step=1, label="top_k (0 to disable)")
+            top_k = gr.Slider(0, 200, value=40, step=1, label="top_k (0 disables)")
            repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.01, label="repetition_penalty")
 
            run_btn = gr.Button("Generate", variant="primary")
-            csv_btn = gr.Button("Download CSV")
 
        with gr.Column(scale=1):
+            # Keep last results for stable downloads
+            state_df = gr.State(value=None)
+
            out_df = gr.Dataframe(
                headers=["user_prompt", "response", "tokens_out"],
                datatype=["str", "str", "number"],
@@ -175,11 +201,18 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
                wrap=True,
                interactive=False,
                row_count=(0, "dynamic"),
-                type="pandas",
+                type="pandas",  # ensure callbacks get a pandas DataFrame
+            )
+
+            download_btn = gr.DownloadButton(
+                label="Download CSV",
+                value=None,  # we update this with bytes on demand
+                file_name="batch.csv",
            )
-            out_file = gr.File(label="CSV file", visible=False)
 
-    def _generate(model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
+    # -------- Callbacks --------
+
+    def _generate_cb(model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
        df = generate_batch(
            model_id=model_id,
            system_prompt=system_prompt,
@@ -190,20 +223,23 @@ with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
        )
-        return df
+        return df, df  # show in table, also store in state
 
-    def _download(df):
-        path = to_csv(df)
-        return gr.File.update(value=path, visible=True)
-
    run_btn.click(
-        _generate,
+        _generate_cb,
        inputs=[model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
-        outputs=out_df,
+        outputs=[out_df, state_df],
        api_name="generate_batch",
    )
 
-    csv_btn.click(_download, inputs=out_df, outputs=out_file, api_name="download_csv")
+    def _prepare_csv_cb(df_state):
+        if df_state is None or len(df_state) == 0:
+            df_state = pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])
+        csv_bytes = df_to_csv_bytes(df_state)
+        ts = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
+        return gr.DownloadButton.update(value=csv_bytes, file_name=f"batch_{ts}.csv")
+
+    download_btn.click(_prepare_csv_cb, inputs=[state_df], outputs=[download_btn], api_name="download_csv")
 
 if __name__ == "__main__":
    demo.launch()
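
For quick testing of the endpoint this commit wires up (run_btn.click(..., api_name="generate_batch")), the running Space can also be called from Python with gradio_client. The sketch below is illustrative only and not part of the commit: the Space id is a placeholder, and the positional arguments simply mirror the inputs list passed to run_btn.click above.

# Illustrative client-side call (not part of this commit).
# "neovalle/<space-name>" is a placeholder Space id; argument order follows
# the inputs list wired into run_btn.click in app.py.
from gradio_client import Client

client = Client("neovalle/<space-name>")  # placeholder
result = client.predict(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",     # model_id
    "You are a concise, helpful assistant.",  # system_prompt
    "Explain transformers in simple terms\nGive 3 eco-friendly tips for students",  # user prompts, one per line
    256,   # max_new_tokens
    0.7,   # temperature
    0.9,   # top_p
    40,    # top_k
    1.1,   # repetition_penalty
    api_name="/generate_batch",
)
print(result)  # expected: table payload with user_prompt / response / tokens_out columns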