neovalle committed
Commit 66771c3 · verified · Parent: e517122

Create app.py

Files changed (1)
  1. app.py +214 -0
app.py ADDED
@@ -0,0 +1,214 @@
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
from datetime import datetime
import pandas as pd

# ---------- Config ----------

# Small, free chat models that run on CPU in a basic Space (pick one if you like)
DEFAULT_MODELS = [
    "google/gemma-2-2b-it",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen2.5-1.5B-Instruct",
]

# Cache for loaded models to avoid reloading on each call
_MODEL_CACHE = {}

def _load_model(model_id: str):
    """Load tokenizer and model (cached)."""
    if model_id in _MODEL_CACHE:
        return _MODEL_CACHE[model_id]

    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    # Batched generation with decoder-only models needs a pad token and left padding
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    # Use bfloat16 on GPU; fall back to float32 on CPU
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
        device_map="auto",  # requires the accelerate package
    )

    _MODEL_CACHE[model_id] = (tok, model)
    return tok, model

def _format_prompt(tokenizer, system_prompt: str, user_prompt: str) -> str:
    """
    Use the model's chat template if available; otherwise
    create a simple system+user concatenation.
    """
    sys = system_prompt.strip() if system_prompt else ""
    usr = user_prompt.strip()

    if hasattr(tokenizer, "apply_chat_template") and tokenizer.chat_template:
        messages = []
        if sys:
            messages.append({"role": "system", "content": sys})
        messages.append({"role": "user", "content": usr})
        try:
            return tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
        except Exception:
            # Some chat templates (e.g. Gemma's) reject a system role;
            # fold the system prompt into the user turn instead.
            merged = f"{sys}\n\n{usr}" if sys else usr
            return tokenizer.apply_chat_template(
                [{"role": "user", "content": merged}],
                tokenize=False,
                add_generation_prompt=True,
            )
    # Fallback: a lightweight instruction format
    prompt = ""
    if sys:
        prompt += f"<<SYS>>\n{sys}\n<</SYS>>\n\n"
    prompt += f"<<USER>>\n{usr}\n<</USER>>\n<<ASSISTANT>>\n"
    return prompt

def generate_batch(
    model_id: str,
    system_prompt: str,
    prompts_multiline: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):
    """Generate for multiple user prompts (one per line)."""
    tok, model = _load_model(model_id)
    device = model.device

    # Split lines, drop empties
    prompts = [p.strip() for p in prompts_multiline.splitlines() if p.strip()]
    if not prompts:
        return pd.DataFrame([{"user_prompt": "", "response": "", "tokens_out": 0}])

    # Prepare inputs
    formatted = [_format_prompt(tok, system_prompt, p) for p in prompts]
    inputs = tok(
        formatted,
        return_tensors="pt",
        padding=True,
        truncation=True,
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=(temperature > 0.0),
            temperature=temperature if temperature > 0 else None,
            top_p=top_p,
            top_k=top_k if top_k > 0 else None,
            repetition_penalty=repetition_penalty,
            eos_token_id=tok.eos_token_id,
            pad_token_id=tok.eos_token_id,
        )

    # Strip the prompt text to keep only the generated continuation
    gen_texts = []
    for i in range(outputs.size(0)):
        # Safest across tokenizers: decode the full sequence and the prompt,
        # then remove the prompt prefix from the decoded text.
        full = tok.decode(outputs[i], skip_special_tokens=True)
        prompt_only = tok.decode(inputs["input_ids"][i], skip_special_tokens=True)
        resp = full[len(prompt_only):].strip()
        gen_texts.append(resp)

    df = pd.DataFrame(
        {
            "user_prompt": prompts,
            "response": gen_texts,
            "tokens_out": [len(tok.encode(t)) for t in gen_texts],
        }
    )
    return df

def to_csv(df: pd.DataFrame):
    ts = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
    path = f"/tmp/batch_{ts}.csv"
    df.to_csv(path, index=False)
    return path

# ---------- UI ----------

with gr.Blocks(title="Multi-Prompt Chat (System Prompt Control)") as demo:
    gr.Markdown(
        """
# 🧪 Multi-Prompt Chat for HF Space
Pick a small free model, set a **system prompt**, and enter **multiple user prompts** (one per line).
Click **Generate** to get batched responses as a table (downloadable as CSV).
"""
    )

    with gr.Row():
        with gr.Column(scale=1):
            model_id = gr.Dropdown(
                choices=DEFAULT_MODELS,
                value=DEFAULT_MODELS[0],
                label="Model",
                info="Free, small instruction-tuned models that run on CPU in a basic Space.",
            )
            system_prompt = gr.Textbox(
                label="System prompt",
                placeholder="e.g., You are an ecolinguistics-aware assistant that prefers concise, actionable answers.",
                lines=5,
            )
            prompts_multiline = gr.Textbox(
                label="User prompts (one per line)",
                placeholder="Write one query per line.\nExample:\nExplain transformers in simple terms\nGive 3 eco-friendly tips for students\nSummarise the benefits of multilingual models",
                lines=10,
            )

            with gr.Accordion("Generation settings", open=False):
                max_new_tokens = gr.Slider(16, 1024, value=256, step=1, label="max_new_tokens")
                temperature = gr.Slider(0.0, 2.0, value=0.7, step=0.05, label="temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="top_p")
                top_k = gr.Slider(0, 200, value=40, step=1, label="top_k (0 to disable)")
                repetition_penalty = gr.Slider(1.0, 2.0, value=1.1, step=0.01, label="repetition_penalty")

            run_btn = gr.Button("Generate", variant="primary")
            csv_btn = gr.Button("Download CSV")

        with gr.Column(scale=1):
            out_df = gr.Dataframe(
                headers=["user_prompt", "response", "tokens_out"],
                datatype=["str", "str", "number"],
                label="Results",
                wrap=True,
                interactive=False,
                row_count=(0, "dynamic"),
            )
            out_file = gr.File(label="CSV file", visible=False)

    def _generate(model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty):
        df = generate_batch(
            model_id=model_id,
            system_prompt=system_prompt,
            prompts_multiline=prompts_multiline,
            max_new_tokens=int(max_new_tokens),
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
            repetition_penalty=float(repetition_penalty),
        )
        return df

    def _download(df):
        # Gradio may pass the table back as a DataFrame or a list of rows; normalise to a DataFrame
        if isinstance(df, list):
            df = pd.DataFrame(df, columns=["user_prompt", "response", "tokens_out"])
        else:
            df = pd.DataFrame(df)
        path = to_csv(df)
        # gr.update() works on both Gradio 3.x and 4.x (Component.update was removed in 4.x)
        return gr.update(value=path, visible=True)

    run_btn.click(
        _generate,
        inputs=[model_id, system_prompt, prompts_multiline, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
        outputs=out_df,
        api_name="generate_batch",
    )

    csv_btn.click(_download, inputs=out_df, outputs=out_file, api_name="download_csv")

if __name__ == "__main__":
    demo.launch()
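
Because both click handlers register named endpoints (`api_name="generate_batch"` and `api_name="download_csv"`), the batch generation can also be driven programmatically once the Space is running. Below is a minimal sketch using `gradio_client`; the Space id `neovalle/multi-prompt-chat` is a placeholder, and it assumes the Space's `requirements.txt` lists `gradio`, `transformers`, `torch`, `accelerate` (needed for `device_map="auto"`), and `pandas`.

# Sketch: call the Space's named endpoint from Python with gradio_client.
# The Space id below is hypothetical -- replace it with the real one.
from gradio_client import Client

client = Client("neovalle/multi-prompt-chat")  # placeholder Space id

result = client.predict(
    "google/gemma-2-2b-it",          # model_id
    "You are a concise assistant.",  # system_prompt
    "Explain transformers in simple terms\nGive 3 eco-friendly tips for students",  # user prompts, one per line
    256,   # max_new_tokens
    0.7,   # temperature
    0.9,   # top_p
    40,    # top_k
    1.1,   # repetition_penalty
    api_name="/generate_batch",
)
print(result)  # table data: one row per input prompt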