Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preloading audio for Edge TTS and other fixes #138

Merged
merged 4 commits into from
Jan 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/readest-app/src-tauri/tauri.conf.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"security": {
"csp": {
"default-src": "'self' 'unsafe-inline' blob: customprotocol: asset: http://asset.localhost ipc: http://ipc.localhost https://fonts.gstatic.com https://db.onlinewebfonts.com",
"connect-src": "'self' blob: asset: http://asset.localhost ipc: http://ipc.localhost https://*.sentry.io https://*.posthog.com https://*.deepl.com https://*.wikipedia.org https://*.wiktionary.org https://*.supabase.co https://*.readest.com",
"connect-src": "'self' blob: asset: http://asset.localhost ipc: http://ipc.localhost https://*.sentry.io https://*.posthog.com https://*.deepl.com https://*.wikipedia.org https://*.wiktionary.org https://*.supabase.co https://*.readest.com wss://speech.platform.bing.com",
"img-src": "'self' blob: data: asset: http://asset.localhost https://*",
"style-src": "'self' 'unsafe-inline' blob: asset: http://asset.localhost",
"frame-src": "'self' blob: asset: http://asset.localhost",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ const FontDropdown: React.FC<DropdownProps> = ({
>
{moreOptions.map((option) => (
<li key={option} onClick={() => onSelect(option)}>
<div className='flex items-center px-0'>
<div className='flex items-center px-2'>
<span style={{ minWidth: '20px' }}>
{selected === option && <MdCheck size={20} className='text-base-content' />}
</span>
Expand Down
16 changes: 16 additions & 0 deletions apps/readest-app/src/app/reader/components/tts/TTSControl.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ const TTSControl = () => {

setBookKey(bookKey);

if (ttsControllerRef.current) {
ttsControllerRef.current.stop();
ttsControllerRef.current = null;
}

try {
const ttsController = new TTSController(view);
await ttsController.init();
Expand Down Expand Up @@ -94,6 +99,8 @@ const TTSControl = () => {
setIsPlaying(false);
setIsPaused(true);
} else if (isPaused) {
// start for forward/backward/setvoice-paused
// setting the rate doesn't pause the TTS
if (ttsController.state === 'paused') {
ttsController.resume();
} else {
Expand Down Expand Up @@ -163,6 +170,14 @@ const TTSControl = () => {
return [];
};

// Returns whatever the current TTS controller's getVoiceId() reports,
// or '' when no controller has been created yet.
const handleGetVoiceId = () => {
  const activeController = ttsControllerRef.current;
  return activeController ? activeController.getVoiceId() : '';
};

const updatePanelPosition = () => {
if (iconRef.current) {
const rect = iconRef.current.getBoundingClientRect();
Expand Down Expand Up @@ -228,6 +243,7 @@ const TTSControl = () => {
onSetRate={handleSetRate}
onGetVoices={handleGetVoices}
onSetVoice={handleSetVoice}
onGetVoiceId={handleGetVoiceId}
/>
</Popup>
)}
Expand Down
10 changes: 9 additions & 1 deletion apps/readest-app/src/app/reader/components/tts/TTSPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ type TTSPanelProps = {
onSetRate: (rate: number) => void;
onGetVoices: (lang: string) => Promise<TTSVoice[]>;
onSetVoice: (voice: string) => void;
onGetVoiceId: () => string;
};

const TTSPanel = ({
Expand All @@ -31,6 +32,7 @@ const TTSPanel = ({
onSetRate,
onGetVoices,
onSetVoice,
onGetVoiceId,
}: TTSPanelProps) => {
const _ = useTranslation();
const { getViewSettings, setViewSettings } = useReaderStore();
Expand Down Expand Up @@ -58,6 +60,12 @@ const TTSPanel = ({
setViewSettings(bookKey, viewSettings);
};

// Runs once after the first render (empty dependency list): seed the selected
// voice with the id reported by onGetVoiceId().
useEffect(() => {
  setSelectedVoice(onGetVoiceId());
  // eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

useEffect(() => {
const fetchVoices = async () => {
const voices = await onGetVoices(ttsLang);
Expand Down Expand Up @@ -131,7 +139,7 @@ const TTSPanel = ({
key={`${index}-${voice.id}`}
onClick={() => !voice.disabled && handleSelectVoice(voice.id)}
>
<div className='flex items-center px-0'>
<div className='flex items-center px-2'>
<span style={{ minWidth: '20px' }}>
{selectedVoice === voice.id && (
<MdCheck size={20} className='text-base-content' />
Expand Down
35 changes: 27 additions & 8 deletions apps/readest-app/src/libs/edgeTTS.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import { md5 } from 'js-md5';
import { randomMd5 } from '@/utils/misc';
import { LRUCache } from '@/utils/lru';

const EDGE_SPEECH_URL =
'wss://speech.platform.bing.com/consumer/speech/synthesize/readaloud/edge/v1';
Expand Down Expand Up @@ -46,18 +48,25 @@ const genVoiceList = (voices: Record<string, string[]>) => {
};

/**
 * Parameters of a single Edge TTS synthesis request.
 * The whole object is also hashed to form the audio-cache key.
 */
export interface EdgeTTSPayload {
/** Written to the `xml:lang` attribute of the generated SSML `<speak>` element. */
lang: string;
/** Text placed inside the SSML `<prosody>` element. */
text: string;
/** Written to the `name` attribute of the SSML `<voice>` element. */
voice: string;
/** Written to the `rate` attribute of the SSML `<prosody>` element. */
rate: number;
/** NOTE(review): not read by the SSML builder visible here; confirm where pitch is applied. */
pitch: number;
}

/**
 * Derive the audio-cache key for a synthesis payload.
 *
 * The key is built from the declared payload fields in a fixed order, so two
 * payloads with equal field values always hash identically regardless of the
 * order in which their properties were assigned or of any extra properties.
 *
 * @param payload - The synthesis request to hash.
 * @returns The MD5 hex digest of the serialized field values.
 */
const hashPayload = (payload: EdgeTTSPayload): string => {
  const { lang, text, voice, rate, pitch } = payload;
  return md5(JSON.stringify([lang, text, voice, rate, pitch]));
};

export class EdgeSpeechTTS {
static voices = genVoiceList(EDGE_TTS_VOICES);
private static audioCache = new LRUCache<string, AudioBuffer>(200);

constructor() {}

async #fetchEdgeSpeechWs({ text, voice, rate }: EdgeTTSPayload): Promise<Response> {
async #fetchEdgeSpeechWs({ lang, text, voice, rate }: EdgeTTSPayload): Promise<Response> {
const connectId = randomMd5();
const url = `${EDGE_SPEECH_URL}?ConnectionId=${connectId}&TrustedClientToken=${EDGE_API_TOKEN}`;
const date = new Date().toString();
Expand All @@ -83,9 +92,9 @@ export class EdgeSpeechTTS {
},
});

const genSSML = (text: string, voice: string, rate: number) => {
const genSSML = (lang: string, text: string, voice: string, rate: number) => {
return `
<speak version="1.0" xml:lang="en-US">
<speak version="1.0" xml:lang="${lang}">
<voice name="${voice}">
<prosody rate="${rate}">
${text}
Expand Down Expand Up @@ -126,7 +135,7 @@ export class EdgeSpeechTTS {
return { headers, body };
};

const ssml = genSSML(text, voice, rate);
const ssml = genSSML(lang, text, voice, rate);
const content = genSendContent(contentHeaders, ssml);
const config = genSendContent(configHeaders, configContent);

Expand Down Expand Up @@ -177,9 +186,19 @@ export class EdgeSpeechTTS {
}

/**
 * Synthesize `payload` and decode the response into an `AudioBuffer`.
 *
 * Decoded buffers are stored in the static `audioCache`, keyed by
 * `hashPayload(payload)`, so a repeated request for the same payload is served
 * from the cache without another synthesis round trip.
 *
 * @param payload - The synthesis request.
 * @returns The decoded audio for the payload.
 * @throws Propagates any rejection from `create`, `Response.arrayBuffer` or
 *   `decodeAudioData`; nothing is cached in that case.
 */
async createAudio(payload: EdgeTTSPayload): Promise<AudioBuffer> {
  const cacheKey = hashPayload(payload);
  if (EdgeSpeechTTS.audioCache.has(cacheKey)) {
    return EdgeSpeechTTS.audioCache.get(cacheKey)!;
  }

  const res = await this.create(payload);
  const arrayBuffer = await res.arrayBuffer();

  // This context exists only to decode the response. Close it once decoding
  // settles so each call does not leave a live AudioContext behind.
  const audioContext = new AudioContext();
  try {
    const audioBuffer = await audioContext.decodeAudioData(arrayBuffer.slice(0));
    EdgeSpeechTTS.audioCache.set(cacheKey, audioBuffer);
    return audioBuffer;
  } finally {
    // Best effort: a failure to close must not mask the decode result or error.
    void audioContext.close().catch((error: unknown) => {
      console.warn('Error closing decode AudioContext:', error);
    });
  }
}
}
49 changes: 35 additions & 14 deletions apps/readest-app/src/services/tts/EdgeTTSClient.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { getUserLocale } from '@/utils/misc';
import { TTSClient, TTSMessageEvent, TTSVoice } from './TTSClient';
import { EdgeSpeechTTS } from '@/libs/edgeTTS';
import { EdgeSpeechTTS, EdgeTTSPayload } from '@/libs/edgeTTS';
import { parseSSMLLang, parseSSMLMarks } from '@/utils/ssml';
import { TTSGranularity } from '@/types/view';

Expand All @@ -27,6 +27,7 @@ export class EdgeTTSClient implements TTSClient {
this.#voices = EdgeSpeechTTS.voices;
try {
await this.#edgeTTS.create({
lang: 'en',
text: 'test',
voice: 'en-US-AriaNeural',
rate: 1.0,
Expand All @@ -39,29 +40,39 @@ export class EdgeTTSClient implements TTSClient {
return this.available;
}

/**
 * Build the synthesis payload for one piece of text, using this client's
 * current rate and pitch.
 *
 * The return type is annotated (rather than asserted with `as`) so the
 * compiler checks that every payload field is present and correctly typed.
 *
 * @param lang - Language of the text.
 * @param text - Text to synthesize.
 * @param voiceId - Id of the voice to use.
 * @returns The payload to pass to `EdgeSpeechTTS`.
 */
getPayload = (lang: string, text: string, voiceId: string): EdgeTTSPayload => ({
  lang,
  text,
  voice: voiceId,
  rate: this.#rate,
  pitch: this.#pitch,
});

async *speak(ssml: string): AsyncGenerator<TTSMessageEvent> {
const { marks } = parseSSMLMarks(ssml);
const lang = parseSSMLLang(ssml) || 'en';

let voiceId = 'en-US-AriaNeural';
if (!this.#voice) {
const voices = await this.getVoices(lang);
this.#voice = voices[0] ? voices[0] : this.#voices.find((v) => v.id === voiceId) || null;
}
if (this.#voice) {
voiceId = this.#voice.id;
} else {
const voices = await this.getVoices(lang);
voiceId = voices[0]?.id || voiceId;
}

this.stopInternal();

for (const mark of marks) {
try {
this.#audioBuffer = await this.#edgeTTS.createAudio({
text: mark.text.replace(/\r?\n/g, ''),
voice: voiceId,
rate: this.#rate,
pitch: this.#pitch,
// Preloading for longer ssml
if (marks.length > 1) {
for (const mark of marks.slice(1)) {
this.#edgeTTS.createAudio(this.getPayload(lang, mark.text, voiceId)).catch((error) => {
console.warn('Error preloading mark:', mark, error);
});
}
}

for (const mark of marks) {
try {
this.#audioBuffer = await this.#edgeTTS.createAudio(
this.getPayload(lang, mark.text, voiceId),
);
this.#audioContext = new AudioContext();
this.#sourceNode = this.#audioContext.createBufferSource();
this.#sourceNode.buffer = this.#audioBuffer;
Expand Down Expand Up @@ -89,10 +100,16 @@ export class EdgeTTSClient implements TTSClient {
this.#startedAt = this.#audioContext.currentTime;
});
yield result;
if (result.code === 'error') {
break;
}
} catch (error) {
if (error instanceof Error && error.message === 'No audio data received.') {
console.warn('No audio data received for:', mark.text);
yield {
code: 'end',
message: `Chunk finished: ${mark.name}`,
};
continue;
}
console.log('Error:', error);
yield {
code: 'error',
message: error instanceof Error ? error.message : String(error),
Expand Down Expand Up @@ -175,4 +192,8 @@ export class EdgeTTSClient implements TTSClient {
getGranularities(): TTSGranularity[] {
return ['sentence'];
}

/** Returns the id of the currently set voice, or '' when no voice is set or its id is empty. */
getVoiceId(): string {
  const currentVoice = this.#voice;
  return currentVoice ? currentVoice.id || '' : '';
}
}
1 change: 1 addition & 0 deletions apps/readest-app/src/services/tts/TTSClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ export interface TTSClient {
getAllVoices(): Promise<TTSVoice[]>;
getVoices(lang: string): Promise<TTSVoice[]>;
getGranularities(): TTSGranularity[];
getVoiceId(): string;
}
17 changes: 14 additions & 3 deletions apps/readest-app/src/services/tts/TTSController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,13 @@ import { TTSClient, TTSMessageCode, TTSVoice } from './TTSClient';
import { WebSpeechClient } from './WebSpeechClient';
import { EdgeTTSClient } from './EdgeTTSClient';

/**
 * Playback states of the TTS controller.
 *
 * - `'stopped'` is the initial state and the state set by `error()`.
 * - `'playing'` is required for the controller to advance automatically.
 * - `'setvoice-paused'` is set at the start of `setVoice()`.
 */
type TTSState =
  | 'stopped'
  | 'playing'
  | 'paused'
  | 'backward-paused'
  | 'forward-paused'
  | 'setvoice-paused';

export class TTSController extends EventTarget {
state: TTSState = 'stopped';
Expand Down Expand Up @@ -53,7 +59,7 @@ export class TTSController extends EventTarget {
if (!ssml) {
this.#nossmlCnt++;
// FIXME: in case we are at the end of the book, need a better way to handle this
if (this.#nossmlCnt < 10) {
if (this.#nossmlCnt < 10 && this.state === 'playing') {
await this.view.next(1);
this.forward();
}
Expand All @@ -71,7 +77,7 @@ export class TTSController extends EventTarget {
lastCode = code;
}

if (lastCode === 'end') {
if (lastCode === 'end' && this.state === 'playing') {
this.forward();
}
}
Expand Down Expand Up @@ -147,6 +153,7 @@ export class TTSController extends EventTarget {
}

async setVoice(voiceId: string) {
this.state = 'setvoice-paused';
this.ttsClient.stop();
if (this.ttsEdgeVoices.find((voice) => voice.id === voiceId && !voice.disabled)) {
this.ttsClient = this.ttsEdgeClient;
Expand All @@ -156,6 +163,10 @@ export class TTSController extends EventTarget {
await this.ttsClient.setVoice(voiceId);
}

/** Returns the voice id reported by the TTS client currently in use (`this.ttsClient`). */
getVoiceId() {
return this.ttsClient.getVoiceId();
}

error(e: unknown) {
console.error(e);
this.state = 'stopped';
Expand Down
Loading
Loading