1
0

koboldAiChat.ts 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import { ChatbotBackend } from "@/types/backend";
  2. import { Message } from "./messages";
  3. import { buildPrompt } from "@/utils/buildPrompt";
  /**
   * Resolved settings handed from `getKoboldAiChatResponseStream` to the
   * request helpers in this file.
   */
  interface KoboldAIParams {
    /** Server base URL; endpoint paths such as `/api/v1/generate` are appended to it. */
    koboldai_url: string;
    /** Additional stop sequences, separated by `||`. */
    koboldai_stop_sequence: string;
    /** Character name passed to `buildPrompt`; `"<name>:"` is also sent as a stop sequence. */
    name: string;
    /** System prompt passed to `buildPrompt`. */
    system_prompt: string;
  }
  /**
   * Generates a chat reply from a KoboldAI-compatible backend as a text stream.
   *
   * When `config.koboldai_use_extra` is the string `'true'`, the streaming
   * endpoint is used (`getExtra`); otherwise the non-streaming endpoint is
   * called and its result is re-chunked (`getNormal`).
   *
   * @param name - Character name used for the prompt and the `"<name>:"` stop sequence.
   * @param system_prompt - System prompt passed to `buildPrompt`.
   * @param config - KoboldAI backend settings.
   * @param messages - Conversation history.
   * @returns A `ReadableStream` of text chunks.
   * @throws Error if `config` / `config.koboldai_url` is missing, or if the backend request fails.
   */
  export async function getKoboldAiChatResponseStream(
    name: string,
    system_prompt: string,
    config: ChatbotBackend["koboldai"],
    messages: Message[],
  ) {
    if (!config?.koboldai_url) {
      // Without a URL the request would target "undefined/api/...", i.e. a
      // relative path on the current origin; fail with a clear message instead.
      throw new Error("KoboldAi backend is not configured (koboldai_url is missing)");
    }
    const params: KoboldAIParams = {
      name,
      system_prompt,
      koboldai_stop_sequence: config.koboldai_stop_sequence ?? "",
      // Endpoint paths start with "/", so drop trailing slashes to avoid "//api/...".
      koboldai_url: config.koboldai_url.replace(/\/+$/, ""),
    };
    return config.koboldai_use_extra === "true"
      ? getExtra(params, messages)
      : getNormal(params, messages);
  }
  // koboldcpp / stream support
  /**
   * Requests a completion from the streaming endpoint
   * (`/api/extra/generate/stream`) and re-exposes the generated tokens as a
   * stream of text chunks.
   *
   * The response body is read line by line; for every `data:` line the JSON
   * payload is parsed and its truthy `token` field is enqueued. Other lines are
   * ignored and payloads that fail to parse are logged and skipped.
   *
   * @param config - Prompt and connection settings.
   * @param messages - Conversation history passed to `buildPrompt`.
   * @returns A stream of token strings.
   * @throws Error when the server answers with a non-2xx status or without a body.
   */
  async function getExtra(config: KoboldAIParams, messages: Message[]) {
    const headers: Record<string, string> = {
      "Content-Type": "application/json",
    };
    const prompt = buildPrompt({ name: config.name, system_prompt: config.system_prompt }, messages);
    // `"<name>:"` is always a stop sequence; additional ones are `||`-separated.
    // Empty entries (e.g. from an unset setting) are dropped so that neither ""
    // nor the literal string "undefined" is sent as a stop sequence.
    const stop_sequence: string[] = [
      `${config.name}:`,
      ...(config.koboldai_stop_sequence ?? "").split("||").filter((s) => s.length > 0),
    ];
    const res = await fetch(`${config.koboldai_url}/api/extra/generate/stream`, {
      headers: headers,
      method: "POST",
      body: JSON.stringify({
        prompt,
        stop_sequence,
      }),
    });
    if (!res.ok || !res.body) {
      // Release the connection instead of leaving an unread body behind.
      await res.body?.cancel().catch(() => undefined);
      throw new Error(`KoboldAi chat error (${res.status})`);
    }
    const reader = res.body.getReader();
    const decoder = new TextDecoder("utf-8");
    // Set by cancel(): once the consumer has gone away the controller must not
    // be enqueued to, closed, or errored any more (those calls would throw).
    let cancelled = false;

    const stream = new ReadableStream({
      async start(controller: ReadableStreamDefaultController) {
        // Forwards the `token` of a single `data: {...}` line, if any.
        const handleLine = (rawLine: string) => {
          const line = rawLine.trim();
          if (cancelled || !line.startsWith("data:")) return;
          try {
            const json = JSON.parse(line.substring(5));
            const messagePiece = json.token;
            if (messagePiece) {
              controller.enqueue(messagePiece);
            }
          } catch (error) {
            console.error("JSON parsing error:", error, "in line:", line);
          }
        };

        let buffer = "";
        try {
          while (true) {
            const { done, value } = await reader.read();
            if (done) break;
            // `stream: true` keeps a multi-byte UTF-8 character that is split
            // across two network chunks from being decoded as U+FFFD.
            buffer += decoder.decode(value, { stream: true });
            let eolIndex: number;
            while ((eolIndex = buffer.indexOf("\n")) >= 0) {
              handleLine(buffer.substring(0, eolIndex));
              buffer = buffer.substring(eolIndex + 1);
            }
          }
          if (cancelled) return;
          // Flush the decoder and handle a final line that had no trailing newline.
          buffer += decoder.decode();
          handleLine(buffer);
          // Close only on success: close() after error() would throw.
          controller.close();
        } catch (error) {
          if (cancelled) return;
          console.error("Stream error:", error);
          controller.error(error);
        } finally {
          reader.releaseLock();
        }
      },
      async cancel(reason) {
        cancelled = true;
        // Resolves the pending reader.read() in start(), whose finally releases the lock.
        await reader.cancel(reason);
      },
    });
    return stream;
  }
  // koboldai / no stream support
  /**
   * Requests a completion from the non-streaming endpoint (`/api/v1/generate`)
   * and re-chunks the complete text word by word, so callers can consume it the
   * same way as the streaming backend.
   *
   * @param config - Prompt and connection settings.
   * @param messages - Conversation history passed to `buildPrompt`.
   * @returns A stream of text chunks whose concatenation equals the generated text.
   * @throws Error on a non-2xx HTTP status, a response without a `results`
   *   array, or an empty `results` array.
   */
  async function getNormal(config: KoboldAIParams, messages: Message[]) {
    const headers: Record<string, string> = {
      "Content-Type": "application/json",
    };
    const prompt = buildPrompt({ name: config.name, system_prompt: config.system_prompt }, messages);
    // `"<name>:"` is always a stop sequence; additional ones are `||`-separated.
    // Empty entries (e.g. from an unset setting) are dropped so that neither ""
    // nor the literal string "undefined" is sent as a stop sequence.
    const stop_sequence: string[] = [
      `${config.name}:`,
      ...(config.koboldai_stop_sequence ?? "").split("||").filter((s) => s.length > 0),
    ];
    const res = await fetch(`${config.koboldai_url}/api/v1/generate`, {
      headers: headers,
      method: "POST",
      body: JSON.stringify({
        prompt,
        stop_sequence,
      }),
    });
    if (!res.ok) {
      // Release the connection instead of leaving an unread body behind.
      await res.body?.cancel().catch(() => undefined);
      throw new Error(`KoboldAi chat error (${res.status})`);
    }
    const json = await res.json();
    const results: unknown = json?.results;
    if (!Array.isArray(results)) {
      throw new Error(`KoboldAi response has no results`);
    }
    if (results.length === 0) {
      throw new Error(`KoboldAi result length 0`);
    }
    const text = results.map((row: { text: string }) => row.text).join("");

    // A space is re-attached to every word except the last, so the chunks
    // concatenate back to `text` exactly (no spurious trailing space).
    const words = text.split(" ");
    const stream = new ReadableStream({
      start(controller: ReadableStreamDefaultController) {
        words.forEach((word, index) => {
          const chunk = index < words.length - 1 ? `${word} ` : word;
          if (chunk) controller.enqueue(chunk);
        });
        controller.close();
      },
    });
    return stream;
  }