Spaces:
Runtime error
Runtime error
add loose threshold/ remove speed limitation
Browse files- .gitignore +2 -1
- app.py +22 -6
- config/Arthur.yaml +2 -5
- local/check_data.py +15 -1
- local/indicator_plot.py +97 -0
.gitignore
CHANGED
|
@@ -540,4 +540,5 @@ user/
|
|
| 540 |
|
| 541 |
.vscode
|
| 542 |
|
| 543 |
-
!data/Patient_sil_trim_16k_normed_5_snr_40/*
|
|
|
|
|
|
| 540 |
|
| 541 |
.vscode
|
| 542 |
|
| 543 |
+
!data/Patient_sil_trim_16k_normed_5_snr_40/*
|
| 544 |
+
downloads
|
app.py
CHANGED
|
@@ -178,6 +178,13 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
|
|
| 178 |
truth_transform=transformation,
|
| 179 |
hypothesis_transform=transformation,
|
| 180 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
# MOS
|
| 182 |
batch = {
|
| 183 |
"wav": out_wavs,
|
|
@@ -187,7 +194,12 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
|
|
| 187 |
with torch.no_grad():
|
| 188 |
output = model(batch)
|
| 189 |
predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
|
| 190 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
# Phonemes per minute (PPM)
|
| 192 |
with torch.no_grad():
|
| 193 |
logits = phoneme_model(out_wavs).logits
|
|
@@ -204,6 +216,10 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
|
|
| 204 |
fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr)
|
| 205 |
ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
error_msg = "!!! ERROR MESSAGE !!!\n"
|
| 208 |
if audio_path == _ or audio_path == None:
|
| 209 |
error_msg += "ERROR: Fail recording, Please start from the beginning again."
|
|
@@ -216,11 +232,11 @@ def calc_mos(_, audio_path, id, ref, pre_ppm, fig=None):
|
|
| 216 |
ppm,
|
| 217 |
error_msg,
|
| 218 |
)
|
| 219 |
-
if ppm >= float(pre_ppm) + float(config["thre"]["maxppm"]):
|
| 220 |
-
|
| 221 |
-
elif ppm <= float(pre_ppm) - float(config["thre"]["minppm"]):
|
| 222 |
-
|
| 223 |
-
|
| 224 |
error_msg += "ERROR: Naturalness is too low, Please try again.\n"
|
| 225 |
elif wer >= float(config["thre"]["WER"]):
|
| 226 |
error_msg += "ERROR: Intelligibility is too low, Please try again\n"
|
|
|
|
| 178 |
truth_transform=transformation,
|
| 179 |
hypothesis_transform=transformation,
|
| 180 |
)
|
| 181 |
+
|
| 182 |
+
# round to 1 decimal
|
| 183 |
+
wer = np.round(wer, 1)
|
| 184 |
+
|
| 185 |
+
# WER convert to Intellibility score
|
| 186 |
+
INTELI_score = WER2INTELI(wer*100)
|
| 187 |
+
|
| 188 |
# MOS
|
| 189 |
batch = {
|
| 190 |
"wav": out_wavs,
|
|
|
|
| 194 |
with torch.no_grad():
|
| 195 |
output = model(batch)
|
| 196 |
predic_mos = output.mean(dim=1).squeeze().detach().numpy() * 2 + 3
|
| 197 |
+
|
| 198 |
+
# round to 1 decimal
|
| 199 |
+
predic_mos = np.round(predic_mos, 1)
|
| 200 |
+
|
| 201 |
+
# MOS to AVA MOS
|
| 202 |
+
AVA_MOS = nat2avaMOS(predic_mos)
|
| 203 |
# Phonemes per minute (PPM)
|
| 204 |
with torch.no_grad():
|
| 205 |
logits = phoneme_model(out_wavs).logits
|
|
|
|
| 216 |
fig_h = plot_UV(wav_vad.numpy().squeeze(), a_h, sr=sr)
|
| 217 |
ppm = len(lst_phonemes) / (wav_vad.shape[-1] / sr) * 60
|
| 218 |
|
| 219 |
+
|
| 220 |
+
ppm = np.round(ppm, 1)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
error_msg = "!!! ERROR MESSAGE !!!\n"
|
| 224 |
if audio_path == _ or audio_path == None:
|
| 225 |
error_msg += "ERROR: Fail recording, Please start from the beginning again."
|
|
|
|
| 232 |
ppm,
|
| 233 |
error_msg,
|
| 234 |
)
|
| 235 |
+
# if ppm >= float(pre_ppm) + float(config["thre"]["maxppm"]):
|
| 236 |
+
# error_msg += "ERROR: Please speak slower.\n"
|
| 237 |
+
# elif ppm <= float(pre_ppm) - float(config["thre"]["minppm"]):
|
| 238 |
+
# error_msg += "ERROR: Please speak faster.\n"
|
| 239 |
+
if predic_mos <= float(config["thre"]["AUTOMOS"]):
|
| 240 |
error_msg += "ERROR: Naturalness is too low, Please try again.\n"
|
| 241 |
elif wer >= float(config["thre"]["WER"]):
|
| 242 |
error_msg += "ERROR: Intelligibility is too low, Please try again\n"
|
config/Arthur.yaml
CHANGED
|
@@ -3,10 +3,7 @@ ref_txt: data/Arthur_the_rat.txt
|
|
| 3 |
ref_feature: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat.csv
|
| 4 |
ref_wavs: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat
|
| 5 |
thre:
|
| 6 |
-
minppm:
|
| 7 |
-
maxppm:
|
| 8 |
WER: 0.5
|
| 9 |
AUTOMOS: 2.0
|
| 10 |
-
auth:
|
| 11 |
-
username: Kath
|
| 12 |
-
password: Kath
|
|
|
|
| 3 |
ref_feature: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat.csv
|
| 4 |
ref_wavs: data/Patient_sil_trim_16k_normed_5_snr_40/Arthur_the_rat
|
| 5 |
thre:
|
| 6 |
+
minppm: 0
|
| 7 |
+
maxppm: 1000
|
| 8 |
WER: 0.5
|
| 9 |
AUTOMOS: 2.0
|
|
|
|
|
|
|
|
|
local/check_data.py
CHANGED
|
@@ -27,9 +27,23 @@ import io
|
|
| 27 |
import sys
|
| 28 |
|
| 29 |
file_id = sys.argv[1]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# "1YjON2ObGM826KaaqF-sKM7CO0tAtzWGg"
|
| 31 |
# Get the file's metadata
|
| 32 |
-
|
|
|
|
| 33 |
|
| 34 |
request = service.files().get_media(fileId=file_id)
|
| 35 |
with open(file['name'], 'wb') as file_obj:
|
|
|
|
| 27 |
import sys
|
| 28 |
|
| 29 |
file_id = sys.argv[1]
|
| 30 |
+
if file_id == "all":
|
| 31 |
+
results = service.files().list().execute()
|
| 32 |
+
files = results.get('files', [])
|
| 33 |
+
# download all files
|
| 34 |
+
for file in files:
|
| 35 |
+
request = service.files().get_media(fileId=file['id'])
|
| 36 |
+
with open("download/" + file['name'], 'wb') as file_obj:
|
| 37 |
+
downloader = MediaIoBaseDownload(file_obj, request)
|
| 38 |
+
done = False
|
| 39 |
+
while not done:
|
| 40 |
+
status, done = downloader.next_chunk()
|
| 41 |
+
print(f"Download {int(status.progress() * 100)}%.")
|
| 42 |
+
|
| 43 |
# "1YjON2ObGM826KaaqF-sKM7CO0tAtzWGg"
|
| 44 |
# Get the file's metadata
|
| 45 |
+
else:
|
| 46 |
+
file = service.files().get(fileId=file_id).execute()
|
| 47 |
|
| 48 |
request = service.files().get_media(fileId=file_id)
|
| 49 |
with open(file['name'], 'wb') as file_obj:
|
local/indicator_plot.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import plotly.graph_objects as go
|
| 2 |
+
|
| 3 |
+
def Intelligibility_Plot(Int_Score, fair_thre=30, good_thre = 70, Upper=100, Lower=0):
|
| 4 |
+
'''
|
| 5 |
+
Int_Score: a float number between 0 and 100
|
| 6 |
+
Upper: the upper bound of the plot
|
| 7 |
+
Lower: the lower bound of the plot
|
| 8 |
+
'''
|
| 9 |
+
# Assert Nat_Score is a float number between 0 and 100
|
| 10 |
+
assert isinstance(Int_Score, float|int)
|
| 11 |
+
assert Int_Score >= Lower
|
| 12 |
+
assert Int_Score <= Upper
|
| 13 |
+
# Indicator plot with different colors, under fair_threshold the plot is red, then yellow, then green
|
| 14 |
+
# Design 1: Show bar in different colors refer to the threshold
|
| 15 |
+
|
| 16 |
+
color = "#75DA99"
|
| 17 |
+
if Int_Score <= fair_thre:
|
| 18 |
+
color = "#F2ADA0"
|
| 19 |
+
elif Int_Score <= good_thre:
|
| 20 |
+
color = "#e8ee89"
|
| 21 |
+
else:
|
| 22 |
+
color = "#75DA99"
|
| 23 |
+
|
| 24 |
+
fig = go.Figure(go.Indicator(
|
| 25 |
+
mode="number+gauge",
|
| 26 |
+
gauge={'shape': "bullet",
|
| 27 |
+
'axis':{'range': [Lower, Upper]},
|
| 28 |
+
'bgcolor': 'white',
|
| 29 |
+
'bar': {'color': color},
|
| 30 |
+
},
|
| 31 |
+
value=Int_Score,
|
| 32 |
+
domain = {'x': [0, 1], 'y': [0, 1]},
|
| 33 |
+
)
|
| 34 |
+
)
|
| 35 |
+
# # Design 2: Show all thresholds in the background
|
| 36 |
+
# fig = go.Figure(go.Indicator(
|
| 37 |
+
# mode = "number+gauge",
|
| 38 |
+
# gauge = {'shape': "bullet",
|
| 39 |
+
# 'axis': {'range': [Lower, Upper]},
|
| 40 |
+
# 'bgcolor': 'white',
|
| 41 |
+
# 'steps': [
|
| 42 |
+
# {'range': [Lower, fair_thre], 'color': "#F2ADA0"},
|
| 43 |
+
# {'range': [fair_thre, good_thre], 'color': "#e8ee89"},
|
| 44 |
+
# {'range': [good_thre, Upper], 'color': " #75DA99"}],
|
| 45 |
+
# 'bar': {'color': "grey"},
|
| 46 |
+
# },
|
| 47 |
+
# value = Int_Score,
|
| 48 |
+
# domain = {'x': [0, 1], 'y': [0, 1]},
|
| 49 |
+
# )
|
| 50 |
+
# )
|
| 51 |
+
fig.update_layout(height=300, width=1000)
|
| 52 |
+
return fig
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def Naturalness_Plot(Nat_Score, fair_thre=2, good_thre = 4, Upper=5, Lower=1.0):
|
| 56 |
+
'''
|
| 57 |
+
Int_Score: a float number between 0 and 100
|
| 58 |
+
Upper: the upper bound of the plot
|
| 59 |
+
Lower: the lower bound of the plot
|
| 60 |
+
'''
|
| 61 |
+
# Assert Nat_Score is a float number between 0 and 100
|
| 62 |
+
assert isinstance(Nat_Score, float|int)
|
| 63 |
+
assert Nat_Score >= Lower
|
| 64 |
+
assert Nat_Score <= Upper
|
| 65 |
+
|
| 66 |
+
# Indicator plot with different colors, under fair_threshold the plot is red, then yellow, then green
|
| 67 |
+
|
| 68 |
+
color = "#75DA99"
|
| 69 |
+
if Nat_Score <= fair_thre:
|
| 70 |
+
color = "#F2ADA0"
|
| 71 |
+
elif Nat_Score <= good_thre:
|
| 72 |
+
color = "#e8ee89"
|
| 73 |
+
else:
|
| 74 |
+
color = "#75DA99"
|
| 75 |
+
|
| 76 |
+
fig = go.Figure(go.Indicator(
|
| 77 |
+
mode="number+gauge",
|
| 78 |
+
gauge={'shape': "bullet",
|
| 79 |
+
'axis':{'range': [Lower, Upper]},
|
| 80 |
+
'bgcolor': 'white',
|
| 81 |
+
'bar': {'color': color},
|
| 82 |
+
},
|
| 83 |
+
value=Nat_Score,
|
| 84 |
+
domain = {'x': [0, 1], 'y': [0, 1]},
|
| 85 |
+
)
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
fig.update_layout(height=300, width=1000)
|
| 89 |
+
return fig
|
| 90 |
+
|
| 91 |
+
# test case Intelligibility_Plot
|
| 92 |
+
x = Intelligibility_Plot(10)
|
| 93 |
+
x.show()
|
| 94 |
+
x = Intelligibility_Plot(50)
|
| 95 |
+
x.show()
|
| 96 |
+
x = Intelligibility_Plot(90)
|
| 97 |
+
x.show()
|