bosh94 commited on
Commit
0718fad
·
verified ·
1 Parent(s): f79bf21

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -141
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import os
2
- import inspect
3
  import numpy as np
4
  import pandas as pd
5
  import gradio as gr
@@ -8,54 +7,29 @@ import torch
8
 
9
  from chronos import Chronos2Pipeline
10
 
 
11
  # =========================
12
  # Config
13
  # =========================
14
  MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
15
  DATA_DIR = "data"
16
 
 
17
  # =========================
18
- # Helpers: files & device
19
  # =========================
20
  def available_test_csv():
21
  if not os.path.isdir(DATA_DIR):
22
  return []
23
  return sorted(f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv"))
24
 
 
25
  def pick_device(ui_choice: str) -> str:
26
  if (ui_choice or "").startswith("cuda") and torch.cuda.is_available():
27
  return "cuda"
28
  return "cpu"
29
 
30
- # =========================
31
- # Model cache
32
- # =========================
33
- _PIPELINE = None
34
- _PIPELINE_META = {}
35
 
36
- def get_pipeline(model_id: str, device: str):
37
- """
38
- Caches the pipeline across calls to avoid re-downloading and re-loading.
39
- """
40
- global _PIPELINE, _PIPELINE_META
41
-
42
- model_id = (model_id or MODEL_ID_DEFAULT).strip()
43
- device = "cuda" if device == "cuda" and torch.cuda.is_available() else "cpu"
44
-
45
- if (
46
- _PIPELINE is None
47
- or _PIPELINE_META.get("model_id") != model_id
48
- or _PIPELINE_META.get("device") != device
49
- ):
50
- # Chronos-2 pipeline
51
- _PIPELINE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
52
- _PIPELINE_META = {"model_id": model_id, "device": device}
53
-
54
- return _PIPELINE
55
-
56
- # =========================
57
- # Data generation/loading
58
- # =========================
59
  def make_sample_series(n, seed, trend, season_period, season_amp, noise):
60
  rng = np.random.default_rng(int(seed))
61
  t = np.arange(int(n))
@@ -64,30 +38,30 @@ def make_sample_series(n, seed, trend, season_period, season_amp, noise):
64
  + float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period)))
65
  + rng.normal(0.0, float(noise), size=len(t))
66
  )
67
- # shift up if negative (not required, but keeps nice plots)
68
  mn = float(np.min(y))
69
  if mn < 0:
70
  y = y - mn
71
  return y.astype(np.float32)
72
 
 
73
  def load_series_from_csv(path_or_file, column=None):
74
  df = pd.read_csv(path_or_file)
75
-
76
  if df.shape[1] == 0:
77
  raise ValueError("CSV vuoto o non leggibile.")
78
 
79
  col = (column or "").strip()
80
  if col == "":
 
81
  numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
 
82
  if not numeric_cols:
83
- # try coercion to numeric on all columns (sometimes dtype is object)
84
- numeric_cols = []
85
  for c in df.columns:
86
  coerced = pd.to_numeric(df[c], errors="coerce")
87
  if coerced.notna().sum() >= 10:
88
  numeric_cols.append(c)
89
- if not numeric_cols:
90
- raise ValueError("Nessuna colonna numerica nel CSV. Specifica una colonna con numeri.")
91
  col = numeric_cols[0]
92
 
93
  if col not in df.columns:
@@ -99,76 +73,78 @@ def load_series_from_csv(path_or_file, column=None):
99
 
100
  return y.astype(np.float32), col
101
 
 
102
  # =========================
103
- # Chronos2 predict normalization
104
  # =========================
105
- def _extract_samples(pred_out):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  """
107
- Chronos2Pipeline.predict may return:
108
- - numpy array / list -> samples
109
- - dict with 'samples'
110
- - object with attribute 'samples'
111
- This returns np.ndarray of shape (n_draws, pred_len) or (pred_len,) if only one draw.
112
  """
113
- if isinstance(pred_out, np.ndarray):
114
- return pred_out
115
- if isinstance(pred_out, list):
116
- return np.asarray(pred_out)
117
- if isinstance(pred_out, dict):
118
- if "samples" in pred_out:
119
- return np.asarray(pred_out["samples"])
120
- # sometimes "forecast" keys etc.
121
- for k in ("predictions", "prediction", "outputs"):
122
- if k in pred_out:
123
- return np.asarray(pred_out[k])
124
- return np.asarray(pred_out)
125
- # object with samples attribute
126
- if hasattr(pred_out, "samples"):
127
- return np.asarray(getattr(pred_out, "samples"))
128
- # last resort
129
- return np.asarray(pred_out)
130
-
131
- def chronos2_predict_samples(pipe, y, prediction_length: int, n_draws: int):
132
  """
133
- Calls pipe.predict in a robust way across Chronos versions:
134
- - Uses `inputs=` (required)
135
- - Uses `num_predictions=` if supported
136
- - If not supported, falls back to a single prediction and returns shape (1, pred_len)
137
  """
138
- sig = inspect.signature(pipe.predict)
139
- params = sig.parameters
140
-
141
- kwargs = {"inputs": y.tolist(), "prediction_length": int(prediction_length)}
142
-
143
- # API differences: some versions accept num_predictions, others not
144
- if "num_predictions" in params:
145
- kwargs["num_predictions"] = int(n_draws)
146
-
147
- # Some versions might have different names; try a couple safe fallbacks
148
- try:
149
- out = pipe.predict(**kwargs)
150
- except TypeError as e:
151
- # If num_predictions was rejected, retry without it
152
- if "num_predictions" in kwargs:
153
- kwargs.pop("num_predictions", None)
154
- out = pipe.predict(**kwargs)
155
- else:
156
- raise e
157
-
158
- samples = _extract_samples(out).astype(np.float32)
159
-
160
- # Normalize shape: expected (n_draws, pred_len)
161
- if samples.ndim == 1:
162
- samples = samples[None, :]
163
- elif samples.ndim == 2:
164
- pass
165
- else:
166
- # If extra dims, squeeze conservatively
167
- samples = np.squeeze(samples)
168
- if samples.ndim == 1:
169
- samples = samples[None, :]
170
-
171
- return samples
172
 
173
  # =========================
174
  # Forecast core
@@ -185,33 +161,32 @@ def run_forecast(
185
  season_amp,
186
  noise,
187
  prediction_length,
188
- num_draws,
189
  q_low,
190
  q_high,
191
  device_ui,
192
  model_id,
193
  ):
194
- # Validate quantiles
195
- if float(q_low) >= float(q_high):
 
196
  raise gr.Error("Quantile low deve essere < quantile high.")
197
 
198
- # Device + pipeline
199
  device = pick_device(device_ui)
200
  pipe = get_pipeline(model_id, device)
201
 
202
- # Choose input series
203
  if input_mode == "Test CSV":
204
  if not test_csv_name:
205
- raise gr.Error("Seleziona un file nella dropdown dei Test CSV oppure usa Sample/Upload.")
206
- csv_path = os.path.join(DATA_DIR, test_csv_name)
207
- if not os.path.exists(csv_path):
208
- raise gr.Error(f"Non trovo {csv_path}. Assicurati che esista nel repo dello Space.")
209
- y, used_col = load_series_from_csv(csv_path, csv_column)
210
  source = f"Test CSV: {test_csv_name} ({used_col})"
211
 
212
  elif input_mode == "Upload CSV":
213
  if upload_csv is None:
214
- raise gr.Error("Carica un CSV oppure scegli Sample/Test CSV.")
215
  y, used_col = load_series_from_csv(upload_csv.name, csv_column)
216
  source = f"Upload CSV ({used_col})"
217
 
@@ -219,27 +194,41 @@ def run_forecast(
219
  y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
220
  source = "Sample data"
221
 
222
- # Forecast samples
223
- samples = chronos2_predict_samples(
224
- pipe=pipe,
225
- y=y,
 
 
 
226
  prediction_length=int(prediction_length),
227
- n_draws=int(num_draws),
 
 
 
228
  )
229
 
230
- # Quantiles
231
- median = np.quantile(samples, 0.50, axis=0)
232
- low = np.quantile(samples, float(q_low), axis=0)
233
- high = np.quantile(samples, float(q_high), axis=0)
 
 
 
 
 
 
 
 
234
 
235
- # Plot
236
  t_hist = np.arange(len(y))
237
  t_fcst = np.arange(len(y), len(y) + int(prediction_length))
238
 
239
  fig, ax = plt.subplots(figsize=(10, 4))
240
  ax.plot(t_hist, y, label="history")
241
  ax.plot(t_fcst, median, label="forecast (median)")
242
- ax.fill_between(t_fcst, low, high, alpha=0.25, label=f"band [{float(q_low):.2f}, {float(q_high):.2f}]")
243
  ax.axvline(len(y) - 1, linestyle="--", linewidth=1)
244
  ax.set_title(source)
245
  ax.set_xlabel("t")
@@ -247,13 +236,14 @@ def run_forecast(
247
  ax.grid(True, alpha=0.3)
248
  ax.legend()
249
 
250
- # Output table + CSV
251
  out_df = pd.DataFrame(
252
  {
253
  "t": t_fcst,
 
254
  "median": median,
255
- f"q{float(q_low):.2f}": low,
256
- f"q{float(q_high):.2f}": high,
257
  }
258
  )
259
 
@@ -266,25 +256,27 @@ def run_forecast(
266
  "source": source,
267
  "history_points": int(len(y)),
268
  "prediction_length": int(prediction_length),
269
- "requested_draws": int(num_draws),
270
- "returned_draws": int(samples.shape[0]),
271
  }
272
 
273
  return fig, out_df, out_path, info
274
 
 
275
  # =========================
276
  # UI
277
  # =========================
278
  with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
279
- gr.Markdown("# ⏱️ Chronos-2 Forecast Demo (HF Spaces)\n\n"
280
- "Supporta **Sample**, **Test CSV** (da cartella `data/`) e **Upload CSV**.")
 
 
 
 
 
281
 
282
  with gr.Row():
283
- input_mode = gr.Radio(
284
- ["Sample", "Test CSV", "Upload CSV"],
285
- value="Sample",
286
- label="Input source",
287
- )
288
  device_ui = gr.Dropdown(
289
  ["cpu", "cuda (se disponibile)"],
290
  value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
@@ -293,10 +285,7 @@ with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
293
  model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")
294
 
295
  with gr.Row():
296
- test_csv_name = gr.Dropdown(
297
- choices=available_test_csv(),
298
- label="Test CSV disponibili (cartella data/)",
299
- )
300
  upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
301
  csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
302
 
@@ -310,8 +299,6 @@ with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
310
 
311
  with gr.Accordion("Forecast settings", open=True):
312
  prediction_length = gr.Slider(1, 180, 30, step=1, label="Prediction length")
313
- # UI label stays "Num samples", internally treated as number of prediction draws if supported
314
- num_draws = gr.Slider(1, 400, 200, step=10, label="Num samples (draws)")
315
  q_low = gr.Slider(0.01, 0.49, 0.10, step=0.01, label="Quantile low")
316
  q_high = gr.Slider(0.51, 0.99, 0.90, step=0.01, label="Quantile high")
317
 
@@ -336,7 +323,6 @@ with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
336
  season_amp,
337
  noise,
338
  prediction_length,
339
- num_draws,
340
  q_low,
341
  q_high,
342
  device_ui,
 
1
  import os
 
2
  import numpy as np
3
  import pandas as pd
4
  import gradio as gr
 
7
 
8
  from chronos import Chronos2Pipeline
9
 
10
+
11
  # =========================
12
  # Config
13
  # =========================
14
  MODEL_ID_DEFAULT = os.getenv("CHRONOS_MODEL_ID", "amazon/chronos-2")
15
  DATA_DIR = "data"
16
 
17
+
18
  # =========================
19
+ # Utils
20
  # =========================
21
  def available_test_csv():
22
  if not os.path.isdir(DATA_DIR):
23
  return []
24
  return sorted(f for f in os.listdir(DATA_DIR) if f.lower().endswith(".csv"))
25
 
26
+
27
  def pick_device(ui_choice: str) -> str:
28
  if (ui_choice or "").startswith("cuda") and torch.cuda.is_available():
29
  return "cuda"
30
  return "cpu"
31
 
 
 
 
 
 
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  def make_sample_series(n, seed, trend, season_period, season_amp, noise):
34
  rng = np.random.default_rng(int(seed))
35
  t = np.arange(int(n))
 
38
  + float(season_amp) * np.sin(2 * np.pi * t / max(1, int(season_period)))
39
  + rng.normal(0.0, float(noise), size=len(t))
40
  )
41
+ # shift up if negative to keep plots nice
42
  mn = float(np.min(y))
43
  if mn < 0:
44
  y = y - mn
45
  return y.astype(np.float32)
46
 
47
+
48
  def load_series_from_csv(path_or_file, column=None):
49
  df = pd.read_csv(path_or_file)
 
50
  if df.shape[1] == 0:
51
  raise ValueError("CSV vuoto o non leggibile.")
52
 
53
  col = (column or "").strip()
54
  if col == "":
55
+ # try native numeric dtypes first
56
  numeric_cols = [c for c in df.columns if pd.api.types.is_numeric_dtype(df[c])]
57
+ # fallback: try coercion
58
  if not numeric_cols:
 
 
59
  for c in df.columns:
60
  coerced = pd.to_numeric(df[c], errors="coerce")
61
  if coerced.notna().sum() >= 10:
62
  numeric_cols.append(c)
63
+ if not numeric_cols:
64
+ raise ValueError("Nessuna colonna numerica nel CSV. Specifica la colonna corretta.")
65
  col = numeric_cols[0]
66
 
67
  if col not in df.columns:
 
73
 
74
  return y.astype(np.float32), col
75
 
76
+
77
  # =========================
78
+ # Pipeline cache
79
  # =========================
80
+ _PIPELINE = None
81
+ _PIPELINE_META = {}
82
+
83
+
84
+ def get_pipeline(model_id: str, device: str):
85
+ global _PIPELINE, _PIPELINE_META
86
+
87
+ model_id = (model_id or MODEL_ID_DEFAULT).strip()
88
+ device = "cuda" if device == "cuda" and torch.cuda.is_available() else "cpu"
89
+
90
+ if (
91
+ _PIPELINE is None
92
+ or _PIPELINE_META.get("model_id") != model_id
93
+ or _PIPELINE_META.get("device") != device
94
+ ):
95
+ _PIPELINE = Chronos2Pipeline.from_pretrained(model_id, device_map=device)
96
+ _PIPELINE_META = {"model_id": model_id, "device": device}
97
+
98
+ return _PIPELINE
99
+
100
+
101
+ # =========================
102
+ # Chronos-2 predict_df helpers
103
+ # =========================
104
+ def build_context_df(y: np.ndarray, freq: str = "D"):
105
  """
106
+ Build a minimal context DataFrame compatible with Chronos2Pipeline.predict_df().
107
+ We generate a synthetic timestamp index so it works for Sample and numeric-only CSV.
 
 
 
108
  """
109
+ ts = pd.date_range("2000-01-01", periods=len(y), freq=freq)
110
+ return pd.DataFrame({"id": "series_0", "timestamp": ts, "target": y})
111
+
112
+
113
+ def pick_quantile_column(pred_df: pd.DataFrame, q: float) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  """
115
+ Column naming can vary. We robustly find a column representing quantile q.
116
+ Common patterns: "0.1", "0.5", "0.9" OR "q0.1" OR "quantile_0.1" etc.
 
 
117
  """
118
+ q = float(q)
119
+ # direct numeric-string match
120
+ for c in pred_df.columns:
121
+ try:
122
+ if abs(float(c) - q) < 1e-9:
123
+ return c
124
+ except Exception:
125
+ pass
126
+
127
+ # prefixed patterns
128
+ candidates = []
129
+ for c in pred_df.columns:
130
+ lc = str(c).lower()
131
+ if "quant" in lc or lc.startswith("q"):
132
+ # try to extract float from tail
133
+ for token in [lc.replace("quantile", "").replace("_", ""), lc.replace("q", "")]:
134
+ try:
135
+ if abs(float(token) - q) < 1e-9:
136
+ candidates.append(c)
137
+ except Exception:
138
+ pass
139
+
140
+ if candidates:
141
+ return candidates[0]
142
+
143
+ raise ValueError(
144
+ f"Non riesco a trovare la colonna del quantile {q}. "
145
+ f"Colonne disponibili: {list(pred_df.columns)}"
146
+ )
147
+
 
 
 
 
148
 
149
  # =========================
150
  # Forecast core
 
161
  season_amp,
162
  noise,
163
  prediction_length,
 
164
  q_low,
165
  q_high,
166
  device_ui,
167
  model_id,
168
  ):
169
+ q_low = float(q_low)
170
+ q_high = float(q_high)
171
+ if q_low >= q_high:
172
  raise gr.Error("Quantile low deve essere < quantile high.")
173
 
 
174
  device = pick_device(device_ui)
175
  pipe = get_pipeline(model_id, device)
176
 
177
+ # 1) pick data
178
  if input_mode == "Test CSV":
179
  if not test_csv_name:
180
+ raise gr.Error("Seleziona un file nella dropdown dei Test CSV.")
181
+ path = os.path.join(DATA_DIR, test_csv_name)
182
+ if not os.path.exists(path):
183
+ raise gr.Error(f"Non trovo {path}. Assicurati che sia nel repo.")
184
+ y, used_col = load_series_from_csv(path, csv_column)
185
  source = f"Test CSV: {test_csv_name} ({used_col})"
186
 
187
  elif input_mode == "Upload CSV":
188
  if upload_csv is None:
189
+ raise gr.Error("Carica un CSV per usare la modalità Upload.")
190
  y, used_col = load_series_from_csv(upload_csv.name, csv_column)
191
  source = f"Upload CSV ({used_col})"
192
 
 
194
  y = make_sample_series(n, seed, trend, season_period, season_amp, noise)
195
  source = "Sample data"
196
 
197
+ # 2) build context df (single series)
198
+ context_df = build_context_df(y, freq="D")
199
+
200
+ # 3) predict quantiles via predict_df (stable API per chronos-2)
201
+ quantiles = sorted({q_low, 0.5, q_high})
202
+ pred_df = pipe.predict_df(
203
+ context_df,
204
  prediction_length=int(prediction_length),
205
+ quantile_levels=quantiles,
206
+ id_column="id",
207
+ timestamp_column="timestamp",
208
+ target="target",
209
  )
210
 
211
+ # 4) extract arrays
212
+ col_low = pick_quantile_column(pred_df, q_low)
213
+ col_med = pick_quantile_column(pred_df, 0.5)
214
+ col_high = pick_quantile_column(pred_df, q_high)
215
+
216
+ # pred_df contains the forecast horizon rows; keep only series_0
217
+ pred_df = pred_df[pred_df["id"] == "series_0"].copy()
218
+
219
+ ts_fcst = pd.to_datetime(pred_df["timestamp"]).to_numpy()
220
+ low = pred_df[col_low].to_numpy(dtype=np.float32)
221
+ median = pred_df[col_med].to_numpy(dtype=np.float32)
222
+ high = pred_df[col_high].to_numpy(dtype=np.float32)
223
 
224
+ # 5) plot (use integer axis for simplicity)
225
  t_hist = np.arange(len(y))
226
  t_fcst = np.arange(len(y), len(y) + int(prediction_length))
227
 
228
  fig, ax = plt.subplots(figsize=(10, 4))
229
  ax.plot(t_hist, y, label="history")
230
  ax.plot(t_fcst, median, label="forecast (median)")
231
+ ax.fill_between(t_fcst, low, high, alpha=0.25, label=f"band [{q_low:.2f}, {q_high:.2f}]")
232
  ax.axvline(len(y) - 1, linestyle="--", linewidth=1)
233
  ax.set_title(source)
234
  ax.set_xlabel("t")
 
236
  ax.grid(True, alpha=0.3)
237
  ax.legend()
238
 
239
+ # 6) output table + downloadable csv
240
  out_df = pd.DataFrame(
241
  {
242
  "t": t_fcst,
243
+ "timestamp": ts_fcst,
244
  "median": median,
245
+ f"q{q_low:.2f}": low,
246
+ f"q{q_high:.2f}": high,
247
  }
248
  )
249
 
 
256
  "source": source,
257
  "history_points": int(len(y)),
258
  "prediction_length": int(prediction_length),
259
+ "quantile_levels": quantiles,
260
+ "pred_df_columns": list(out_df.columns),
261
  }
262
 
263
  return fig, out_df, out_path, info
264
 
265
+
266
  # =========================
267
  # UI
268
  # =========================
269
  with gr.Blocks(title="Chronos-2 • HF Spaces Demo") as demo:
270
+ gr.Markdown(
271
+ "# ⏱️ Chronos-2 Forecast Demo (HF Spaces)\n"
272
+ "- **Sample**: genera una serie sintetica\n"
273
+ "- **Test CSV**: usa file in `data/`\n"
274
+ "- **Upload CSV**: carica un tuo CSV\n\n"
275
+ "Questa versione usa **predict_df()** (API consigliata per Chronos-2) e calcola direttamente i **quantili**. "
276
+ )
277
 
278
  with gr.Row():
279
+ input_mode = gr.Radio(["Sample", "Test CSV", "Upload CSV"], value="Sample", label="Input source")
 
 
 
 
280
  device_ui = gr.Dropdown(
281
  ["cpu", "cuda (se disponibile)"],
282
  value="cuda (se disponibile)" if torch.cuda.is_available() else "cpu",
 
285
  model_id = gr.Textbox(value=MODEL_ID_DEFAULT, label="Model ID")
286
 
287
  with gr.Row():
288
+ test_csv_name = gr.Dropdown(choices=available_test_csv(), label="Test CSV disponibili (data/)")
 
 
 
289
  upload_csv = gr.File(label="Upload CSV", file_types=[".csv"])
290
  csv_column = gr.Textbox(label="Colonna numerica (opzionale)", placeholder="es: value")
291
 
 
299
 
300
  with gr.Accordion("Forecast settings", open=True):
301
  prediction_length = gr.Slider(1, 180, 30, step=1, label="Prediction length")
 
 
302
  q_low = gr.Slider(0.01, 0.49, 0.10, step=0.01, label="Quantile low")
303
  q_high = gr.Slider(0.51, 0.99, 0.90, step=0.01, label="Quantile high")
304
 
 
323
  season_amp,
324
  noise,
325
  prediction_length,
 
326
  q_low,
327
  q_high,
328
  device_ui,