speecht5.js 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990
  1. /* eslint-disable camelcase */
  2. import { pipeline, env } from "@xenova/transformers";
  3. // Disable local models
  4. env.allowLocalModels = false;
  5. // Define model factories
  6. // Ensures only one model is created of each type
  /**
   * Base factory that lazily creates one pipeline per class and reuses it.
   *
   * Subclasses configure the pipeline through the static `task`, `model` and
   * `quantized` fields; the created pipeline is stored in the static
   * `instance` field of the class `getInstance` is called on.
   */
  class PipelineFactory {
    static task = null;
    static model = null;
    static quantized = null;
    static instance = null;

    /**
     * Stores the given values on the instance. Nothing in the static
     * `getInstance` path reads these instance fields.
     *
     * @param {*} tokenizer - Stored as `this.tokenizer`.
     * @param {*} model - Stored as `this.model`.
     * @param {*} quantized - Stored as `this.quantized`.
     */
    constructor(tokenizer, model, quantized) {
      this.tokenizer = tokenizer;
      this.model = model;
      this.quantized = quantized;
    }

    /**
     * Returns the shared pipeline for this class, creating it on first use.
     *
     * @param {?Function} progress_callback - Forwarded to `pipeline(...)` as
     *   the `progress_callback` option. Only used by the call that actually
     *   creates the pipeline; later calls reuse the cached instance.
     * @returns {Promise<*>} Resolves to whatever `pipeline(...)` resolves to.
     */
    static async getInstance(progress_callback = null) {
      if (this.instance === null) {
        // Cache the un-awaited result so concurrent callers share one load.
        const pending = pipeline(this.task, this.model, {
          quantized: this.quantized,
          progress_callback,
        });
        this.instance = pending;

        // A failed load must not stay cached, otherwise every later call
        // would receive the same rejection and the load could never be
        // retried. The caller still observes the original rejection.
        Promise.resolve(pending).catch(() => {
          if (this.instance === pending) {
            this.instance = null;
          }
        });
      }
      return this.instance;
    }
  }
  // Worker entry point: every incoming message is treated as one
  // text-to-speech request.
  self.addEventListener("message", async ({ data: request }) => {
    // The request supplies the text plus the speaker embeddings to use.
    // Names seen for embeddings so far (presumably file names — confirm
    // against the caller):
    //   'cmu_us_slt_arctic-wav-arctic_a0001.bin'
    //   'speaker_embeddings.bin'
    const audio = await tts(request.text, request.speaker_embeddings);

    // `tts` resolves to null after it has already reported an error, so
    // only a real result is forwarded to the main thread.
    if (audio !== null) {
      self.postMessage({
        status: "complete",
        task: "text-to-speech",
        data: audio,
      });
    }
  });
  /**
   * Pipeline factory configured for text-to-speech.
   *
   * It only overrides the static configuration; the caching `getInstance`
   * logic is inherited from `PipelineFactory`.
   */
  class TextToSpeechPipelineFactory extends PipelineFactory {
    /** Task name handed to `pipeline(...)`. */
    static task = "text-to-speech";

    /** Model identifier handed to `pipeline(...)`. */
    static model = "Xenova/speecht5_tts";

    /**
     * Passed through as the `quantized` option.
     * NOTE(review): null presumably selects the non-quantized weights —
     * confirm against the transformers.js version in use.
     */
    static quantized = null;
  }
  /**
   * Runs text-to-speech for `text` using the shared pipeline and reports
   * progress, intermediate updates and failures to the main thread through
   * `self.postMessage`.
   *
   * @param {*} text - Input forwarded as the first argument of the pipeline.
   * @param {*} speaker_embeddings - Forwarded to the pipeline as the
   *   `speaker_embeddings` option.
   * @returns {Promise<*>} The pipeline output, or `null` when loading or
   *   running the pipeline failed (an `error` message has then already been
   *   posted).
   */
  const tts = async (text, speaker_embeddings) => {
    // Relays intermediate items as "update" messages.
    // NOTE(review): presumably invoked by the pipeline after each
    // generation step — confirm the pipeline honors `callback_function`.
    const callback_function = (item) => {
      self.postMessage({
        status: "update",
        task: "text-to-speech",
        data: item,
      });
    };

    try {
      // Load (or reuse) the model; load-progress events are forwarded as-is.
      const synthesizer = await TextToSpeechPipelineFactory.getInstance(
        (data) => {
          self.postMessage(data);
        },
      );

      // Actually run tts.
      return await synthesizer(text, {
        speaker_embeddings,
        callback_function,
      });
    } catch (error) {
      // Report the failure instead of letting it surface as an unhandled
      // rejection inside the worker; callers treat null as "already handled".
      self.postMessage({
        status: "error",
        task: "text-to-speech",
        data: error,
      });
      return null;
    }
  };