import os
import torch
import numpy as np
import wave
import tempfile
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
from snac import SNAC
from huggingface_hub import login, upload_file, hf_hub_download, snapshot_download
from datetime import datetime, timezone, timedelta

# Log in to the Hugging Face Hub if a token is provided (needed for
# private repos and for uploading generated audio)
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN:
    login(token=HF_TOKEN)

# Configuration from environment
MODEL_REPO = os.environ.get("MODEL_REPO", "isankhaa/or-my-model")
SUBFOLDER = os.environ.get("SUBFOLDER", "epoch-34")
BASE_MODEL = os.environ.get("BASE_MODEL", "canopylabs/orpheus-tts-0.1-pretrained")
OUTPUT_REPO = os.environ.get("OUTPUT_REPO", "isankhaa/or-my-model")
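# SNAC's 24 kHz codec fixes the output sample rate; VOICE is the speaker
# prefix prepended to every prompt below.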
SAMPLE_RATE = 24000
VOICE = "mongolian"

print(f"Model: {MODEL_REPO}/{SUBFOLDER}")

# Global variables
model = None
tokenizer = None
snac = None

def load_models():
    """Load the tokenizer, fine-tuned model, and SNAC codec once at startup."""
    global model, tokenizer, snac

    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)

    print(f"Loading model from {MODEL_REPO}/{SUBFOLDER}...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_REPO,
        subfolder=SUBFOLDER,
        torch_dtype=torch.bfloat16,
        device_map="cpu"  # Load to CPU first, move to GPU in generate
    )
    model.eval()
    print(f"Model loaded: {model.num_parameters():,} parameters")

    print("Loading SNAC codec...")
    snac = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval()
    print("All models loaded!")

# Load models at startup
load_models()

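# ZeroGPU attaches a GPU only while this function runs (120 s cap per call);
# the models stay on CPU between requests.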
@spaces.GPU(duration=120)
def generate_speech(text, temperature=0.7, top_p=0.9, max_tokens=4096, upload_to_hf=False):
    """Generate speech from text using ZeroGPU"""
    global model, tokenizer, snac

    if not text.strip():
        return None, "Error: Empty text"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Move models to GPU
    model.to(device)
    snac.to(device)

    try:
        # Prepend the voice name, matching the "<voice>: <text>" prompt format
        prompt = f"{VOICE}: {text}"
        text_tokens = tokenizer.encode(prompt, add_special_tokens=False)

        # Wrap the text in Orpheus's prompt-framing special tokens:
        # 128259 opens the human turn; 128009 (end-of-text) and 128260 close it
        input_ids = [128259]
        input_ids.extend(text_tokens)
        input_ids.extend([128009, 128260])

        input_tensor = torch.tensor([input_ids], device=device)
        attention_mask = torch.ones_like(input_tensor)

        print(f"Input tokens: {len(input_ids)}")

        # Generate
        with torch.inference_mode():
            output = model.generate(
                input_tensor,
                attention_mask=attention_mask,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=1.1,
                pad_token_id=128263,
                eos_token_id=128257,
            )

        # Keep only audio tokens from the continuation. Orpheus packs SNAC
        # codes into IDs 128266..156937 (7 positions x 4096 codes per frame);
        # 128257 marks end of speech. Other tokens are skipped, which can
        # misalign frames if they appear mid-stream.
        generated = output[0, len(input_ids):].tolist()
        audio_tokens = []
        for token_id in generated:
            if 128266 <= token_id <= 156937:
                audio_tokens.append(token_id)
            elif token_id == 128257:
                break

        print(f"Generated {len(audio_tokens)} audio tokens")

        if len(audio_tokens) < 7:
            return None, f"Error: Only generated {len(audio_tokens)} audio tokens"

        # Map each token ID back to its raw SNAC code: position idx % 7 within
        # a frame has its own 4096-wide ID block starting at 128266
        snac_tokens = []
        for idx, token_id in enumerate(audio_tokens):
            layer = idx % 7
            snac_val = token_id - 128266 - (layer * 4096)
            snac_tokens.append(snac_val)

        num_frames = len(snac_tokens) // 7
        snac_tokens = snac_tokens[:num_frames * 7]

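        # De-interleave each 7-code frame into SNAC's three hierarchical
        # codebooks: 1 coarse, 2 medium, and 4 fine codes per frame.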
        codes_0, codes_1, codes_2 = [], [], []
        for i in range(num_frames):
            base = i * 7
            codes_0.append(snac_tokens[base])
            codes_1.append(snac_tokens[base + 1])
            codes_1.append(snac_tokens[base + 4])
            codes_2.append(snac_tokens[base + 2])
            codes_2.append(snac_tokens[base + 3])
            codes_2.append(snac_tokens[base + 5])
            codes_2.append(snac_tokens[base + 6])

        codes = [
            torch.tensor([codes_0], device=device, dtype=torch.int32),
            torch.tensor([codes_1], device=device, dtype=torch.int32),
            torch.tensor([codes_2], device=device, dtype=torch.int32),
        ]

        # Clamp to SNAC's valid codebook range in case of out-of-block values
        codes = [torch.clamp(c, 0, 4095) for c in codes]

        # Decode SNAC codes back into a 24 kHz waveform
        with torch.inference_mode():
            audio = snac.decode(codes)

        audio_np = audio.squeeze().cpu().numpy()
        duration = len(audio_np) / SAMPLE_RATE

        # Write 16-bit mono PCM to a temp file; clip first so the int16 cast
        # cannot wrap around on out-of-range samples
        temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        temp_file.close()  # wave.open reopens the file by name
        audio_int16 = (np.clip(audio_np, -1.0, 1.0) * 32767).astype(np.int16)
        with wave.open(temp_file.name, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)  # 2 bytes per sample = 16-bit
            wav_file.setframerate(SAMPLE_RATE)
            wav_file.writeframes(audio_int16.tobytes())

        status = f"Success! Duration: {duration:.2f}s, Audio tokens: {len(audio_tokens)}"

        # Upload to the Hugging Face Hub if requested
        if upload_to_hf and HF_TOKEN:
            try:
                tz_mongolia = timezone(timedelta(hours=8))  # Ulaanbaatar, UTC+8
                timestamp = datetime.now(tz_mongolia).strftime("%Y-%m-%d_%H-%M")
                output_file = f"{SUBFOLDER}-test-{timestamp}.wav"
                upload_path = f"{SUBFOLDER}/test_output/{output_file}"
                upload_file(
                    path_or_fileobj=temp_file.name,
                    path_in_repo=upload_path,
                    repo_id=OUTPUT_REPO,
                    repo_type="model",
                )
                status += f"\nUploaded: https://huggingface.co/{OUTPUT_REPO}/blob/main/{upload_path}"
            except Exception as e:
                status += f"\nUpload failed: {e}"

        return temp_file.name, status

    except Exception as e:
        return None, f"Error: {str(e)}"
    finally:
        # Move models back to CPU so the ZeroGPU allocation is released
        model.to("cpu")
        snac.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Create Gradio interface
with gr.Blocks(title="Mongolian TTS (ZeroGPU)") as demo:
    gr.Markdown(f"""
    # 🎤 Mongolian Text-to-Speech

    An Orpheus TTS model fine-tuned for the Mongolian language.

    **Model:** {MODEL_REPO}/{SUBFOLDER}

    Runs on Hugging Face ZeroGPU (free!)
    """)

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text (Mongolian)",
                placeholder="Энд монгол текст бичнэ үү...",  # "Write Mongolian text here..."
                lines=3,
            )

            with gr.Row():
                temperature = gr.Slider(0.1, 1.5, value=0.7, label="Temperature")
                top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top P")

            max_tokens = gr.Slider(512, 8192, value=4096, step=512, label="Max Tokens")
            upload_checkbox = gr.Checkbox(label="Upload to HuggingFace", value=True)

            generate_btn = gr.Button("🎵 Generate Speech", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="Generated Audio", type="filepath")
            status_output = gr.Textbox(label="Status", lines=3)

    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, temperature, top_p, max_tokens, upload_checkbox],
        outputs=[audio_output, status_output],
    )

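    # Example prompts in Mongolian: a greeting/test sentence, a remark about
    # today's weather, and a longer news-style sentence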
    gr.Examples(
        examples=[
            ["Сайн байна уу, энэ бол монгол хэлний туршилт юм."],
            ["Өнөөдөр цаг агаар сайхан байна."],
            ["Дэд бүтэц, нийгмийн үйлчилгээний хүрээнд ч томоохон бүтээн байгуулалтууд хийгдэх юм."],
        ],
        inputs=[text_input],
    )

demo.launch()