whisper.js 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. /* eslint-disable camelcase */
  2. import { pipeline, env } from "@xenova/transformers";
  3. // Disable local models
  4. env.allowLocalModels = false;
  5. // Define model factories
  6. // Ensures only one model is created of each type
  /**
   * Base class for lazily-created, cached pipelines.
   *
   * Subclasses override the static `task`, `model` and `quantized` fields.
   * `getInstance()` calls `pipeline()` with those values on first use and caches
   * the resulting promise on the class it was invoked on, so at most one
   * pipeline is created per factory class.
   */
  class PipelineFactory {
    /** First argument passed to `pipeline()`. */
    static task = null;
    /** Second argument passed to `pipeline()`. */
    static model = null;
    /** Forwarded to `pipeline()` as the `quantized` option. */
    static quantized = null;
    /** Cached promise returned by `pipeline()`, or `null` when nothing is cached. */
    static instance = null;

    /**
     * Stores the given values on the created object. Nothing in this class
     * reads them; the static API below is what callers use.
     *
     * @param {*} tokenizer - Stored as `this.tokenizer`.
     * @param {*} model - Stored as `this.model`.
     * @param {*} quantized - Stored as `this.quantized`.
     */
    constructor(tokenizer, model, quantized) {
      this.tokenizer = tokenizer;
      this.model = model;
      this.quantized = quantized;
    }

    /**
     * Returns the cached pipeline, creating it on the first call.
     *
     * Concurrent callers share the same pending promise. If creation fails the
     * cache is cleared, so that a later call can retry instead of receiving the
     * same rejected promise forever.
     *
     * @param {?Function} [progress_callback=null] - Forwarded to `pipeline()`
     *   as the `progress_callback` option. Only used by the call that actually
     *   creates the pipeline.
     * @returns {Promise<*>} Resolves with whatever `pipeline()` resolves with;
     *   rejects if `pipeline()` fails.
     */
    static async getInstance(progress_callback = null) {
      if (this.instance === null) {
        const pending = Promise.resolve(
          pipeline(this.task, this.model, {
            quantized: this.quantized,
            progress_callback,
          }),
        );
        this.instance = pending;

        // Do not keep a rejected promise in the cache. The identity check
        // avoids clobbering a newer instance installed in the meantime.
        pending.catch(() => {
          if (this.instance === pending) {
            this.instance = null;
          }
        });
      }
      return this.instance;
    }
  }
  /**
   * Worker entry point. For every message received it transcribes
   * `event.data.audio` and posts the result back to the main thread.
   *
   * Posted messages:
   *  - `{ status: "complete", task, data }` with the transcript on success;
   *  - `{ status: "error", task, data }` with the error when `transcribe`
   *    rejects.
   * Nothing is posted from here when `transcribe` resolves with `null`
   * (it has already reported the failure itself).
   */
  self.addEventListener("message", async (event) => {
    const message = event.data;

    // TODO use the rest of the message data (only `audio` is read for now)
    let transcript;
    try {
      transcript = await transcribe(message.audio);
    } catch (error) {
      // Without this, a rejection would surface only as an unhandled promise
      // rejection inside the worker and the main thread would never hear back.
      self.postMessage({
        status: "error",
        task: "automatic-speech-recognition",
        data: error,
      });
      return;
    }

    if (transcript === null) return;

    // Send the result back to the main thread
    self.postMessage({
      status: "complete",
      task: "automatic-speech-recognition",
      data: transcript,
    });
  });
  /**
   * Factory settings for the speech-recognition pipeline.
   *
   * Only overrides the static configuration fields; instance creation and
   * caching are inherited from `PipelineFactory`.
   */
  class AutomaticSpeechRecognitionPipelineFactory extends PipelineFactory {
    /** Task name handed to the pipeline constructor. */
    static task = "automatic-speech-recognition";

    /**
     * Model identifier handed to the pipeline constructor.
     * TODO load this from config
     * (previously tried alternative: "distil-whisper/distil-medium.en")
     */
    static model = "Xenova/whisper-tiny.en";

    /** Value of the `quantized` option handed to the pipeline constructor. */
    static quantized = true;
  }
  /**
   * Transcribes `audio` with the pipeline obtained from
   * `AutomaticSpeechRecognitionPipelineFactory`.
   *
   * While running it posts to the main thread via `self.postMessage`:
   *  - every progress payload reported while the pipeline is being created;
   *  - `{ status: "update", task, data }` each time the transcriber invokes
   *    `callback_function`, where `data` is the result of decoding all chunks
   *    collected so far;
   *  - `{ status: "error", task, data }` if loading the pipeline or running the
   *    transcription throws.
   *
   * @param {*} audio - Passed unchanged as the first argument of the transcriber.
   * @returns {Promise<*>} The transcriber's output, or `null` when an error
   *   occurred (the error has then already been posted).
   */
  const transcribe = async (audio) => {
    // TODO use subtask and language
    // TODO load model name / quantized flag from config. When they differ from
    //      p.model / p.quantized: update both, dispose the cached instance
    //      (if any) and reset p.instance to null so the model is re-created.
    const p = AutomaticSpeechRecognitionPipelineFactory;

    try {
      // Load transcriber model, forwarding progress payloads to the main thread.
      const transcriber = await p.getInstance((data) => {
        self.postMessage(data);
      });

      const time_precision =
        transcriber.processor.feature_extractor.config.chunk_length /
        transcriber.model.config.max_source_positions;

      // Storage for chunks to be processed. Initialise with an empty chunk.
      const chunks_to_process = [{ tokens: [], finalised: false }];

      // TODO: Storage for fully-processed and merged chunks

      // Called after each chunk is processed.
      const chunk_callback = (chunk) => {
        const last = chunks_to_process[chunks_to_process.length - 1];

        // Overwrite last chunk with new info
        Object.assign(last, chunk);
        last.finalised = true;

        // Open a fresh empty chunk unless this was the final one.
        if (!chunk.is_last) {
          chunks_to_process.push({ tokens: [], finalised: false });
        }
      };

      // Called after each generation step: refresh the tokens of the chunk in
      // progress and publish the merged, decoded result.
      const callback_function = (item) => {
        const last = chunks_to_process[chunks_to_process.length - 1];
        last.tokens = [...item[0].output_token_ids];

        // TODO optimise so we don't have to decode all chunks every time
        const data = transcriber.tokenizer._decode_asr(chunks_to_process, {
          time_precision,
          return_timestamps: true,
          force_full_sequences: false,
        });

        self.postMessage({
          status: "update",
          task: "automatic-speech-recognition",
          data,
        });
      };

      // Actually run transcription. `return await` keeps a rejection inside
      // this try block so it is reported below.
      return await transcriber(audio, {
        // Greedy
        top_k: 0,
        do_sample: false,

        // Sliding window
        chunk_length_s: 30,
        stride_length_s: 5,

        // Language and task
        language: null,
        task: null,

        // Return timestamps
        return_timestamps: true,
        force_full_sequences: false,

        // Callback functions
        callback_function, // after each generation step
        chunk_callback, // after each chunk is processed
      });
    } catch (error) {
      // Covers model loading, config access and the transcription itself.
      self.postMessage({
        status: "error",
        task: "automatic-speech-recognition",
        data: error,
      });
      return null;
    }
  };