From 621c73677a1b9e3c4d0b21a3034a1a3ea5468733 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Mon, 18 Mar 2024 20:48:14 +0200 Subject: [PATCH 01/52] fix: bugs --- src/evaluator/LlamaContext/LlamaContext.ts | 225 ++++++++++++--------- 1 file changed, 132 insertions(+), 93 deletions(-) diff --git a/src/evaluator/LlamaContext/LlamaContext.ts b/src/evaluator/LlamaContext/LlamaContext.ts index f142b52e..3126cca3 100644 --- a/src/evaluator/LlamaContext/LlamaContext.ts +++ b/src/evaluator/LlamaContext/LlamaContext.ts @@ -190,18 +190,22 @@ export class LlamaContext { this._dispatchDecodeScheduled = false; this._batchDispatchPending = false; - let prioritizeStrategy: ReturnType; - try { - this._ensureNotDisposed(); - prioritizeStrategy = resolveBatchItemsPrioritizingStrategy(this._batchingOptions.itemsPrioritizingStrategy); - } catch (err) { - this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); - return; - } + let shouldHaveAnotherLoop = this._queuedDecodes.length > 0; - let shouldHaveAnotherBatch = this._queuedDecodes.length > 0; + const resolvePrioritizingStrategy = () => { + try { + this._ensureNotDisposed(); + return resolveBatchItemsPrioritizingStrategy(this._batchingOptions.itemsPrioritizingStrategy); + } catch (err) { + this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); + } + + return null; + }; - while (shouldHaveAnotherBatch) { + const getOrderedQueuedDecodes = ( + prioritizeStrategy: ReturnType + ): null | CurrentBatchItem[] => { const batchItemToQueuedDecodeMap = new Map(); const batchItemsList: BatchItem[] = []; @@ -222,25 +226,10 @@ export class LlamaContext { }); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); - return; + return null; } - let batchTokenSlotsLeft = this._batchSize; - const afterDecodeActions: Array<{ - batchLogitIndex: BatchLogitIndex | undefined, - response: [accept: (res: any) => void, reject: (reason: unknown) => void], - onDone?: (batchLogitIndex: BatchLogitIndex) => any - }> = []; - const queuedDecodesToDelete = new Set(); - const currentQueuedDecodeItems = new Set(); - - const currentBatchItems: Array<{ - queuedDecode: InternalQueuedDecode, - processAmount: number - }> = []; - let currentBatchSize = 0; - - for (const prioritizedItem of prioritizedItems) { + return prioritizedItems.map((prioritizedItem): CurrentBatchItem => { const queuedDecode = batchItemToQueuedDecodeMap.get(prioritizedItem.item); if (queuedDecode == null) @@ -249,95 +238,140 @@ export class LlamaContext { "of the batch item on `item` on `PrioritizedBatchItem` in your custom prioritization strategy" ); - const processAmount = Math.min(queuedDecode.tokens.length, prioritizedItem.processAmount, batchTokenSlotsLeft); + return { + queuedDecode, + processAmount: prioritizedItem.processAmount + }; + }); + }; + + const fitQueuedDecodesToABatch = (queuedDecodes: CurrentBatchItem[], batchSize: number) => { + const currentBatchItems: CurrentBatchItem[] = []; + let currentBatchSize = 0; + let batchTokenSlotsLeft = batchSize; + + for (const {queuedDecode, processAmount} of queuedDecodes) { + const resolvedProcessAmount = Math.min( + processAmount <= 0 ? 1 : processAmount, queuedDecode.tokens.length, batchTokenSlotsLeft + ); + + if (resolvedProcessAmount <= 0) { + if (batchTokenSlotsLeft === 0) + break; - if (processAmount <= 0) continue; + } - batchTokenSlotsLeft -= processAmount; + batchTokenSlotsLeft -= resolvedProcessAmount; + currentBatchSize += resolvedProcessAmount; currentBatchItems.push({ queuedDecode, - processAmount + processAmount: resolvedProcessAmount }); - currentBatchSize += processAmount; } - let preventDisposalHandle: DisposalPreventionHandle; + return { + currentBatchItems, + currentBatchSize + }; + }; - try { - preventDisposalHandle = this._backendContextDisposeGuard.createPreventDisposalHandle(); - } catch (err) { - this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); - return; + const decodeTokenBatchItems = async (batchItems: CurrentBatchItem[], currentBatchSize: number) => { + const afterDecodeActions: Array<{ + batchLogitIndex: BatchLogitIndex | undefined, + response: [accept: (res: any) => void, reject: (reason: unknown) => void], + onDone?: (batchLogitIndex: BatchLogitIndex) => any + }> = []; + const queuedDecodesToDelete = new Set(); + const currentQueuedDecodeItems = new Set(); + + if (currentBatchSize !== 0) + this._ctx.initBatch(currentBatchSize); + + for (const {queuedDecode, processAmount} of batchItems) { + let batchLogitIndex: ReturnType; + try { + batchLogitIndex = this._ctx.addToBatch( + queuedDecode.sequenceId, + queuedDecode.firstTokenSequenceIndex, + Uint32Array.from(queuedDecode.tokens.slice(0, processAmount)), + queuedDecode.generateLogitAtTheEnd && processAmount === queuedDecode.tokens.length + ); + } catch (err) { + this._dispatchErrorForQueuedDecodesAndDequeue(new Set([queuedDecode]), err); + continue; + } + currentQueuedDecodeItems.add(queuedDecode); + + if (queuedDecode.tokens.length === processAmount) { + queuedDecodesToDelete.add(queuedDecode); + afterDecodeActions.push({ + batchLogitIndex, + response: queuedDecode.response, + onDone: queuedDecode.onDone + }); + } else { + queuedDecode.tokens = queuedDecode.tokens.slice(processAmount); + queuedDecode.firstTokenSequenceIndex += processAmount; + } + } + + for (let i = 0; i < this._queuedDecodes.length; i++) { + const queuedDecode = this._queuedDecodes[i]; + if (queuedDecodesToDelete.has(queuedDecode)) { + this._queuedDecodes.splice(i, 1); + this._queuedDecodeSequenceIds.delete(queuedDecode.sequenceId); + i--; + } } try { if (currentBatchSize !== 0) - this._ctx.initBatch(currentBatchSize); + await this._ctx.decodeBatch(); + } catch (err) { + this._dispatchErrorForQueuedDecodesAndDequeue(currentQueuedDecodeItems, err); + return; + } - for (const {queuedDecode, processAmount} of currentBatchItems) { - let batchLogitIndex: ReturnType; + for (const action of afterDecodeActions) { + const [accept, reject] = action.response; + if (action.onDone != null && action.batchLogitIndex != null) { try { - batchLogitIndex = this._ctx.addToBatch( - queuedDecode.sequenceId, - queuedDecode.firstTokenSequenceIndex, - Uint32Array.from(queuedDecode.tokens.slice(0, processAmount)), - queuedDecode.generateLogitAtTheEnd && processAmount === queuedDecode.tokens.length - ); + accept(action.onDone(action.batchLogitIndex ?? null)); } catch (err) { - this._dispatchErrorForQueuedDecodesAndDequeue(new Set([queuedDecode]), err); - continue; - } - currentQueuedDecodeItems.add(queuedDecode); - - if (queuedDecode.tokens.length === processAmount) { - queuedDecodesToDelete.add(queuedDecode); - afterDecodeActions.push({ - batchLogitIndex, - response: queuedDecode.response, - onDone: queuedDecode.onDone - }); - } else { - queuedDecode.tokens = queuedDecode.tokens.slice(processAmount); - queuedDecode.firstTokenSequenceIndex += processAmount; + reject(err); } - - if (batchTokenSlotsLeft === 0) - break; } - for (let i = 0; i < this._queuedDecodes.length; i++) { - const queuedDecode = this._queuedDecodes[i]; - if (queuedDecodesToDelete.has(queuedDecode)) { - this._queuedDecodes.splice(i, 1); - this._queuedDecodeSequenceIds.delete(queuedDecode.sequenceId); - i--; - } - } + accept(undefined); + } + }; - shouldHaveAnotherBatch = this._queuedDecodes.length > 0; + const prioritizeStrategy = resolvePrioritizingStrategy(); + if (prioritizeStrategy == null) return; // all queued items are rejected and dequeued when we get here - try { - if (currentBatchSize !== 0) - await this._ctx.decodeBatch(); - } catch (err) { - this._dispatchErrorForQueuedDecodesAndDequeue(currentQueuedDecodeItems, err); - return; - } + while (shouldHaveAnotherLoop) { + const orderedQueuedDecodes = getOrderedQueuedDecodes(prioritizeStrategy); + if (orderedQueuedDecodes == null) return; // all queued items are rejected and dequeued when we get here - for (const action of afterDecodeActions) { - const [accept, reject] = action.response; - if (action.onDone != null && action.batchLogitIndex != null) { - try { - accept(action.onDone(action.batchLogitIndex ?? null)); - } catch (err) { - reject(err); - } - } + const { + currentBatchItems, + currentBatchSize + } = fitQueuedDecodesToABatch(orderedQueuedDecodes, this._batchSize); - accept(undefined); - } + let preventDisposalHandle: DisposalPreventionHandle; + try { + preventDisposalHandle = this._backendContextDisposeGuard.createPreventDisposalHandle(); + } catch (err) { + this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); + return; + } + + try { + await decodeTokenBatchItems(currentBatchItems, currentBatchSize); + + shouldHaveAnotherLoop = this._queuedDecodes.length > 0; } finally { preventDisposalHandle.dispose(); } @@ -588,7 +622,7 @@ export class LlamaContextSequence { /** * Erase context tokens in the provided ranges to free up space for new tokens to be generated. - * the start and end of each range are exclusive. + * The start of each range is inclusive, and the end of each range is exclusive. * For example, the range `{start: 0, end: 1}` will remove the token at the `0` index only. */ public async eraseContextTokenRanges(ranges: ContextTokensDeleteRange[]) { @@ -976,6 +1010,11 @@ type InternalQueuedDecode = { onDone?: (batchLogitIndex: BatchLogitIndex) => any }; +type CurrentBatchItem = { + queuedDecode: InternalQueuedDecode, + processAmount: number +}; + function disposeContextIfReferenced(contextRef: WeakRef) { const context = contextRef.deref(); From 1ce8cf1631d3df21ed0f29fa7d9310136f533031 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Mon, 18 Mar 2024 20:48:52 +0200 Subject: [PATCH 02/52] docs: add canonical URL link --- .vitepress/config.ts | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.vitepress/config.ts b/.vitepress/config.ts index 7f066c5d..247d3c3b 100644 --- a/.vitepress/config.ts +++ b/.vitepress/config.ts @@ -108,6 +108,22 @@ export default defineConfig({ pageData.frontmatter.editLink = false; pageData.frontmatter.lastUpdated = false; } + + let canonicalUrl = hostname + pageData.relativePath; + if (canonicalUrl.endsWith("/index.html")) + canonicalUrl = canonicalUrl.slice(0, -"index.html".length); + if (canonicalUrl.endsWith("/index.md")) + canonicalUrl = canonicalUrl.slice(0, -"index.md".length); + else if (canonicalUrl.endsWith(".html")) + canonicalUrl = canonicalUrl.slice(0, -".html".length); + else if (canonicalUrl.endsWith(".md")) + canonicalUrl = canonicalUrl.slice(0, -".md".length); + + pageData.frontmatter.head ??= []; + pageData.frontmatter.head.push([ + "link", + {rel: "canonical", href: canonicalUrl} + ]) }, themeConfig: { editLink: { From 7c333d0bc22db7e935c06da367c60baa7f5294c9 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Mon, 18 Mar 2024 21:08:45 +0200 Subject: [PATCH 03/52] test: switch to new vitest test signature --- .../functionary/chatSession.test.ts | 4 +--- .../modelDependent/functionary/embedding.test.ts | 8 ++------ .../modelDependent/functionary/functions.test.ts | 8 ++------ test/modelDependent/functionary/grammar.test.ts | 4 +--- test/modelDependent/functionary/sanity.test.ts | 4 +--- .../stableCode/asyncContextLoad.test.ts | 4 +--- .../stableCode/asyncModelLoad.test.ts | 12 +++--------- .../modelDependent/stableCode/completion.test.ts | 16 ++++------------ test/modelDependent/stableCode/parallel.test.ts | 16 ++++------------ 9 files changed, 19 insertions(+), 57 deletions(-) diff --git a/test/modelDependent/functionary/chatSession.test.ts b/test/modelDependent/functionary/chatSession.test.ts index 6db300ce..28b8d306 100644 --- a/test/modelDependent/functionary/chatSession.test.ts +++ b/test/modelDependent/functionary/chatSession.test.ts @@ -5,7 +5,7 @@ import {getTestLlama} from "../../utils/getTestLlama.js"; describe("functionary", () => { describe("chat session", () => { - test("restore chat history", async () => { + test("restore chat history", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); const llama = await getTestLlama(); @@ -34,8 +34,6 @@ describe("functionary", () => { const res2 = await chatSession2.prompt("Repeat your answer"); expect(res2).to.eql("6+6 equals 12."); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); }); diff --git a/test/modelDependent/functionary/embedding.test.ts b/test/modelDependent/functionary/embedding.test.ts index 64e97843..aad52cc0 100644 --- a/test/modelDependent/functionary/embedding.test.ts +++ b/test/modelDependent/functionary/embedding.test.ts @@ -4,7 +4,7 @@ import {getTestLlama} from "../../utils/getTestLlama.js"; describe("functionary", () => { describe("embedding", () => { - test("deterministic", async () => { + test("deterministic", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); const llama = await getTestLlama(); @@ -25,11 +25,9 @@ describe("functionary", () => { expect(helloWorld2Embedding.vector).to.eql(helloWorldEmbedding.vector); expect(helloWorld2Embedding.vector).to.not.eql(helloThereEmbedding.vector); - }, { - timeout: 1000 * 60 * 60 * 2 }); - test("deterministic between runs", async () => { + test("deterministic between runs", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); const llama = await getTestLlama(); @@ -56,8 +54,6 @@ describe("functionary", () => { expect(helloWorldEmbedding2.vector).to.eql(helloWorldEmbedding.vector); expect(helloThereEmbedding2.vector).to.eql(helloThereEmbedding.vector); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); }); diff --git a/test/modelDependent/functionary/functions.test.ts b/test/modelDependent/functionary/functions.test.ts index db561eca..40030256 100644 --- a/test/modelDependent/functionary/functions.test.ts +++ b/test/modelDependent/functionary/functions.test.ts @@ -5,7 +5,7 @@ import {getTestLlama} from "../../utils/getTestLlama.js"; describe("functionary", () => { describe("functions", () => { - test("get n-th word", async () => { + test("get n-th word", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); const llama = await getTestLlama(); @@ -39,13 +39,11 @@ describe("functionary", () => { }); expect(res).to.be.eq('The second word is "secret".'); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); describe("functions and grammar", () => { - test("get n-th word", async () => { + test("get n-th word", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); const llama = await getTestLlama(); @@ -96,8 +94,6 @@ describe("functionary", () => { const parsedRes2 = res2SchemaGrammar.parse(res2); expect(parsedRes2).to.eql({word: "secret"}); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); }); diff --git a/test/modelDependent/functionary/grammar.test.ts b/test/modelDependent/functionary/grammar.test.ts index 8bce36fd..fa4f0655 100644 --- a/test/modelDependent/functionary/grammar.test.ts +++ b/test/modelDependent/functionary/grammar.test.ts @@ -5,7 +5,7 @@ import {getTestLlama} from "../../utils/getTestLlama.js"; describe("functionary", () => { describe("grammar", () => { - test("JSON schema", async () => { + test("JSON schema", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); const llama = await getTestLlama(); @@ -41,8 +41,6 @@ describe("functionary", () => { expect(parsedRes.userMessagePositivityScoreFromOneToTen).to.eq(10); expect(parsedRes.verbsInUserMessage).to.eql(["going"]); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); }); diff --git a/test/modelDependent/functionary/sanity.test.ts b/test/modelDependent/functionary/sanity.test.ts index 28b4c144..4204b0d2 100644 --- a/test/modelDependent/functionary/sanity.test.ts +++ b/test/modelDependent/functionary/sanity.test.ts @@ -5,7 +5,7 @@ import {getTestLlama} from "../../utils/getTestLlama.js"; describe("functionary", () => { describe("sanity", () => { - test("How much is 6+6", async () => { + test("How much is 6+6", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); const llama = await getTestLlama(); @@ -22,8 +22,6 @@ describe("functionary", () => { const res = await chatSession.prompt("How much is 6+6"); expect(res).to.eql("6+6 equals 12."); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); }); diff --git a/test/modelDependent/stableCode/asyncContextLoad.test.ts b/test/modelDependent/stableCode/asyncContextLoad.test.ts index 89b348f5..065c2d10 100644 --- a/test/modelDependent/stableCode/asyncContextLoad.test.ts +++ b/test/modelDependent/stableCode/asyncContextLoad.test.ts @@ -4,7 +4,7 @@ import {getTestLlama} from "../../utils/getTestLlama.js"; describe("stableCode", () => { describe("async context load", () => { - test("load asynchronously", async () => { + test("load asynchronously", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -52,8 +52,6 @@ describe("stableCode", () => { expect(loopIterationsBeforeUnload).toBeGreaterThanOrEqual(2); expect(disposePromise).resolves.toBeUndefined(); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); }); diff --git a/test/modelDependent/stableCode/asyncModelLoad.test.ts b/test/modelDependent/stableCode/asyncModelLoad.test.ts index c6a01bf2..cd882b13 100644 --- a/test/modelDependent/stableCode/asyncModelLoad.test.ts +++ b/test/modelDependent/stableCode/asyncModelLoad.test.ts @@ -4,7 +4,7 @@ import {getTestLlama} from "../../utils/getTestLlama.js"; describe("stableCode", () => { describe("async model load", () => { - test("load asynchronously", async () => { + test("load asynchronously", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -49,11 +49,9 @@ describe("stableCode", () => { expect(loopIterationsBeforeUnload).toBeGreaterThanOrEqual(2); expect(disposePromise).resolves.toBeUndefined(); - }, { - timeout: 1000 * 60 * 60 * 2 }); - test("load progress emitted", async () => { + test("load progress emitted", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -88,11 +86,9 @@ describe("stableCode", () => { const model = await modelPromise; await model.dispose(); - }, { - timeout: 1000 * 60 * 60 * 2 }); - test("abort model load works", async () => { + test("abort model load works", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -151,8 +147,6 @@ describe("stableCode", () => { expect(logProgresses[logProgresses.length - 1]).to.not.be.eql(1); expect(Math.max(...logProgressesAfterAbort)).toBeLessThan(0.8); } - }, { - timeout: 1000 * 60 * 60 * 2 }); }); }); diff --git a/test/modelDependent/stableCode/completion.test.ts b/test/modelDependent/stableCode/completion.test.ts index dbd74d28..8028ddf5 100644 --- a/test/modelDependent/stableCode/completion.test.ts +++ b/test/modelDependent/stableCode/completion.test.ts @@ -5,7 +5,7 @@ import {getTestLlama} from "../../utils/getTestLlama.js"; describe("stableCode", () => { describe("completion", () => { - test("complete a series", async () => { + test("complete a series", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -24,11 +24,9 @@ describe("stableCode", () => { }); const expectedFullCompletion = " " + range(4, 20).join(", "); expect(expectedFullCompletion.slice(0, res.length)).to.eql(res); - }, { - timeout: 1000 * 60 * 60 * 2 }); - test("complete pretictable text", async () => { + test("complete pretictable text", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -47,13 +45,11 @@ describe("stableCode", () => { }); const expectedFullCompletion = " going?"; expect(res.slice(0, expectedFullCompletion.length)).to.eql(expectedFullCompletion); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); describe("infill", () => { - test("fill the gap in a series", async () => { + test("fill the gap in a series", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -71,11 +67,9 @@ describe("stableCode", () => { maxTokens: 20 }); expect(res).to.eql(range(4, 9).join(", ") + ", "); - }, { - timeout: 1000 * 60 * 60 * 2 }); - test("fill expected text", async () => { + test("fill expected text", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -93,8 +87,6 @@ describe("stableCode", () => { maxTokens: 10 }); expect(res).to.eql("are you"); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); }); diff --git a/test/modelDependent/stableCode/parallel.test.ts b/test/modelDependent/stableCode/parallel.test.ts index c436df8a..f600c8fd 100644 --- a/test/modelDependent/stableCode/parallel.test.ts +++ b/test/modelDependent/stableCode/parallel.test.ts @@ -5,7 +5,7 @@ import {createTestLlama, getTestLlama} from "../../utils/getTestLlama.js"; describe("stableCode", () => { describe("parallel", () => { - test("can use multiple bindings in parallel", async () => { + test("can use multiple bindings in parallel", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); const llama2 = await createTestLlama(); @@ -53,11 +53,9 @@ describe("stableCode", () => { expect(model2.disposed).toBe(true); expect(context2.disposed).toBe(true); expect(completion2.disposed).toBe(true); - }, { - timeout: 1000 * 60 * 60 * 2 }); - test("can use multiple models in parallel", async () => { + test("can use multiple models in parallel", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -99,11 +97,9 @@ describe("stableCode", () => { const expectedFullCompletion2 = " " + range(96, 1).join(", "); expect(expectedFullCompletion.slice(0, res.length)).to.eql(res); expect(expectedFullCompletion2.slice(0, res2.length)).to.eql(res2); - }, { - timeout: 1000 * 60 * 60 * 2 }); - test("can use multiple contexts in parallel", async () => { + test("can use multiple contexts in parallel", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -142,11 +138,9 @@ describe("stableCode", () => { const expectedFullCompletion2 = " " + range(96, 1).join(", "); expect(expectedFullCompletion.slice(0, res.length)).to.eql(res); expect(expectedFullCompletion2.slice(0, res2.length)).to.eql(res2); - }, { - timeout: 1000 * 60 * 60 * 2 }); - test("can use multiple context sequences in parallel", async () => { + test("can use multiple context sequences in parallel", {timeout: 1000 * 60 * 60 * 2}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -183,8 +177,6 @@ describe("stableCode", () => { const expectedFullCompletion2 = " " + range(96, 1).join(", "); expect(expectedFullCompletion.slice(0, res.length)).to.eql(res); expect(expectedFullCompletion2.slice(0, res2.length)).to.eql(res2); - }, { - timeout: 1000 * 60 * 60 * 2 }); }); }); From 13e1ad64c66d478512c339d7a84ecc803e5f16a5 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Mon, 18 Mar 2024 21:15:25 +0200 Subject: [PATCH 04/52] test: separate gguf tests to model dependent and model independent tests --- .../__snapshots__/gguf.test.ts.snap | 34 ----------------- .../functionary}/gguf.test.ts | 37 ++++--------------- .../gguf/__snapshots__/gguf.test.ts.snap | 35 ++++++++++++++++++ test/standalone/gguf/gguf.test.ts | 25 +++++++++++++ 4 files changed, 68 insertions(+), 63 deletions(-) rename test/{gguf => modelDependent/functionary}/__snapshots__/gguf.test.ts.snap (84%) rename test/{gguf => modelDependent/functionary}/gguf.test.ts (55%) create mode 100644 test/standalone/gguf/__snapshots__/gguf.test.ts.snap create mode 100644 test/standalone/gguf/gguf.test.ts diff --git a/test/gguf/__snapshots__/gguf.test.ts.snap b/test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap similarity index 84% rename from test/gguf/__snapshots__/gguf.test.ts.snap rename to test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap index adfa5a67..21411c0c 100644 --- a/test/gguf/__snapshots__/gguf.test.ts.snap +++ b/test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap @@ -143,37 +143,3 @@ exports[`GGUF Parser > should parse local gguf model 1`] = ` "metadataSize": 718762, } `; - -exports[`GGUF Parser > should parse remote gguf model 1`] = ` -{ - "metadata": { - "falcon": { - "attention": { - "head_count": 232, - "head_count_kv": 8, - "layer_norm_epsilon": 0.000009999999747378752, - }, - "block_count": 80, - "context_length": 2048, - "embedding_length": 14848, - "feed_forward_length": 59392, - "tensor_data_layout": "jploski", - }, - "general": { - "architecture": "falcon", - "file_type": "MOSTLY_Q6_K", - "name": "Falcon", - "quantization_version": 2, - }, - "tensorCount": 644, - "tokenizer": { - "ggml": { - "eos_token_id": 11, - "model": "gpt2", - }, - }, - "version": 2, - }, - "metadataSize": 2547826, -} -`; diff --git a/test/gguf/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts similarity index 55% rename from test/gguf/gguf.test.ts rename to test/modelDependent/functionary/gguf.test.ts index b7ad2bfd..f519bcc3 100644 --- a/test/gguf/gguf.test.ts +++ b/test/modelDependent/functionary/gguf.test.ts @@ -1,18 +1,15 @@ -import {describe, expect, it} from "vitest"; -import GGUFReadStream from "../../src/gguf/ggufParser/stream/GGUFReadStream.js"; -import GGUFParser from "../../src/gguf/ggufParser/GGUFParser.js"; -import GGUFFetchStream from "../../src/gguf/ggufParser/stream/GGUFFetchStream.js"; -import {getModelFile} from "../utils/modelFiles.js"; -import GGUFInsights from "../../src/gguf/GGUFInsights.js"; -import {getTestLlama} from "../utils/getTestLlama.js"; -import GGUFMetadata from "../../src/gguf/GGUFMetadata.js"; - -const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; +import {describe, expect, it, test} from "vitest"; +import GGUFReadStream from "../../../src/gguf/ggufParser/stream/GGUFReadStream.js"; +import GGUFParser from "../../../src/gguf/ggufParser/GGUFParser.js"; +import {getModelFile} from "../../utils/modelFiles.js"; +import GGUFInsights from "../../../src/gguf/GGUFInsights.js"; +import {getTestLlama} from "../../utils/getTestLlama.js"; +import GGUFMetadata from "../../../src/gguf/GGUFMetadata.js"; describe("GGUF Parser", async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); - it("Magic should be GGUF local model", async () => { + test("Magic should be GGUF local model", async () => { const stream = new GGUFReadStream(modelPath); const magic = await stream.readNBytes(4); const magicText = String.fromCharCode(...magic); @@ -29,24 +26,6 @@ describe("GGUF Parser", async () => { expect(metadata).toMatchSnapshot(); }); - it("Magic should be GGUF remote model", async () => { - const stream = new GGUFFetchStream(remoteGGUFModel); - - const magic = await stream.readNBytes(4); - const magicText = String.fromCharCode(...magic); - - expect(magicText).toBe("GGUF"); - }); - - it("should parse remote gguf model", async () => { - const stream = new GGUFFetchStream(remoteGGUFModel); - - const ggufParser = new GGUFParser(stream); - const metadata = await ggufParser.parseMetadata(); - - expect(metadata).toMatchSnapshot(); - }, {timeout: 0}); - it("should calculate GGUF VRAM Usage", async () => { const stream = new GGUFReadStream(modelPath); diff --git a/test/standalone/gguf/__snapshots__/gguf.test.ts.snap b/test/standalone/gguf/__snapshots__/gguf.test.ts.snap new file mode 100644 index 00000000..f91e892d --- /dev/null +++ b/test/standalone/gguf/__snapshots__/gguf.test.ts.snap @@ -0,0 +1,35 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`GGUF Parser > should parse remote gguf model 1`] = ` +{ + "metadata": { + "falcon": { + "attention": { + "head_count": 232, + "head_count_kv": 8, + "layer_norm_epsilon": 0.000009999999747378752, + }, + "block_count": 80, + "context_length": 2048, + "embedding_length": 14848, + "feed_forward_length": 59392, + "tensor_data_layout": "jploski", + }, + "general": { + "architecture": "falcon", + "file_type": "MOSTLY_Q6_K", + "name": "Falcon", + "quantization_version": 2, + }, + "tensorCount": 644, + "tokenizer": { + "ggml": { + "eos_token_id": 11, + "model": "gpt2", + }, + }, + "version": 2, + }, + "metadataSize": 2547826, +} +`; diff --git a/test/standalone/gguf/gguf.test.ts b/test/standalone/gguf/gguf.test.ts new file mode 100644 index 00000000..81070908 --- /dev/null +++ b/test/standalone/gguf/gguf.test.ts @@ -0,0 +1,25 @@ +import {describe, expect, it, test} from "vitest"; +import GGUFParser from "../../../src/gguf/ggufParser/GGUFParser.js"; +import GGUFFetchStream from "../../../src/gguf/ggufParser/stream/GGUFFetchStream.js"; + +const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; + +describe("GGUF Parser", async () => { + test("Magic should be GGUF remote model", {timeout: 1000 * 60 * 10}, async () => { + const stream = new GGUFFetchStream(remoteGGUFModel); + + const magic = await stream.readNBytes(4); + const magicText = String.fromCharCode(...magic); + + expect(magicText).toBe("GGUF"); + }); + + it("should parse remote gguf model", async () => { + const stream = new GGUFFetchStream(remoteGGUFModel); + + const ggufParser = new GGUFParser(stream); + const metadata = await ggufParser.parseMetadata(); + + expect(metadata).toMatchSnapshot(); + }); +}); From 3fc475bdabcdd12ec02ef87e7ef1a9f11147d395 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Mon, 18 Mar 2024 21:18:07 +0200 Subject: [PATCH 05/52] chore: update `.gitignore` --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index a233aac0..9a82b9c8 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ node_modules /.eslintcache /.vitepress/.cache /test/.models +/test/temp /coverage /llama/compile_commands.json From 45edf6a058585c924e94d79614c05ace55e5c703 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Mon, 18 Mar 2024 21:35:24 +0200 Subject: [PATCH 06/52] feat: add disabled recursive clone feature --- src/bindings/utils/cloneLlamaCppRepo.ts | 8 ++- src/config.ts | 6 ++ src/utils/gitReleaseBundles.ts | 94 ++++++++++++++++++++++++- 3 files changed, 105 insertions(+), 3 deletions(-) diff --git a/src/bindings/utils/cloneLlamaCppRepo.ts b/src/bindings/utils/cloneLlamaCppRepo.ts index 75ab85b2..d12bb3f8 100644 --- a/src/bindings/utils/cloneLlamaCppRepo.ts +++ b/src/bindings/utils/cloneLlamaCppRepo.ts @@ -4,7 +4,9 @@ import cliProgress from "cli-progress"; import chalk from "chalk"; import fs from "fs-extra"; import which from "which"; -import {defaultLlamaCppGitHubRepo, defaultLlamaCppRelease, llamaCppDirectory, llamaCppDirectoryInfoFilePath} from "../../config.js"; +import { + defaultLlamaCppGitHubRepo, defaultLlamaCppRelease, enableRecursiveClone, llamaCppDirectory, llamaCppDirectoryInfoFilePath +} from "../../config.js"; import {getGitBundlePathForRelease} from "../../utils/gitReleaseBundles.js"; import {withLockfile} from "../../utils/withLockfile.js"; import {waitForLockfileRelease} from "../../utils/waitForLockfileRelease.js"; @@ -21,7 +23,8 @@ type ClonedLlamaCppRepoTagFile = { export async function cloneLlamaCppRepo( - githubOwner: string, githubRepo: string, tag: string, useBundles: boolean = true, progressLogs: boolean = true + githubOwner: string, githubRepo: string, tag: string, useBundles: boolean = true, progressLogs: boolean = true, + recursive: boolean = enableRecursiveClone ) { const gitBundleForTag = !useBundles ? null : await getGitBundlePathForRelease(githubOwner, githubRepo, tag); const remoteGitUrl = `https://github.com/${githubOwner}/${githubRepo}.git`; @@ -96,6 +99,7 @@ export async function cloneLlamaCppRepo( await gitWithCloneProgress.clone(remoteGitUrl, llamaCppDirectory, { "--depth": 1, "--branch": tag, + ...(recursive ? {"--recursive": null} : {}), "--quiet": null }); }); diff --git a/src/config.ts b/src/config.ts index 39c6dc9e..56508e98 100644 --- a/src/config.ts +++ b/src/config.ts @@ -78,6 +78,12 @@ export const defaultChatSystemPrompt = "You are a helpful, respectful and honest export const cliBinName = "node-llama-cpp"; export const npxRunPrefix = "npx --no "; +// No need for that at the moment. +// Disabled due to a recursive clone of the llama.cpp repo taking up a lot of space (in the embedded bundle) +// and due to making the clone significantly slower. +// The submodules of the repo are not being used for the compilation for the supported backends, so there's no need to clone them. +export const enableRecursiveClone = false; + const documentationUrl = "https://withcatai.github.io/node-llama-cpp"; export const documentationPageUrls = { CUDA: documentationUrl + "/guide/CUDA", diff --git a/src/utils/gitReleaseBundles.ts b/src/utils/gitReleaseBundles.ts index 2615caf0..2e13508a 100644 --- a/src/utils/gitReleaseBundles.ts +++ b/src/utils/gitReleaseBundles.ts @@ -1,11 +1,19 @@ +import path from "path"; import fs from "fs-extra"; import simpleGit from "simple-git"; -import {currentReleaseGitBundlePath, builtinLlamaCppGitHubRepo, llamaCppDirectory} from "../config.js"; +import {currentReleaseGitBundlePath, builtinLlamaCppGitHubRepo, llamaCppDirectory, enableRecursiveClone} from "../config.js"; import {getBinariesGithubRelease} from "../bindings/utils/binariesGithubRelease.js"; import {isGithubReleaseNeedsResolving} from "./resolveGithubRelease.js"; export async function unshallowAndSquashCurrentRepoAndSaveItAsReleaseBundle() { + if (enableRecursiveClone) + await unshallowAndSquashCurrentRepoWithSubmodulesAndSaveItAsReleaseBundle(); + else + await unshallowAndSquashCurrentRepoWithoutSubmodulesAndSaveItAsReleaseBundle(); +} + +async function unshallowAndSquashCurrentRepoWithoutSubmodulesAndSaveItAsReleaseBundle() { if (!(await fs.pathExists(llamaCppDirectory))) throw new Error("llama.cpp directory does not exist"); @@ -52,6 +60,53 @@ export async function unshallowAndSquashCurrentRepoAndSaveItAsReleaseBundle() { await simpleGit(llamaCppDirectory).raw(["bundle", "create", currentReleaseGitBundlePath, "HEAD"]); } +async function unshallowAndSquashCurrentRepoWithSubmodulesAndSaveItAsReleaseBundle() { + if (!(await fs.pathExists(llamaCppDirectory))) + throw new Error("llama.cpp directory does not exist"); + + if (await fs.pathExists(currentReleaseGitBundlePath)) + await fs.remove(currentReleaseGitBundlePath); + + const currentBranch = await getCurrentTagOrBranch(); + + const lastCommit = await simpleGit(llamaCppDirectory).log(["-1"]); + const lastCommitMessage: string | null = lastCommit?.all?.[0]?.message; + const newCommitMessage = "## SQUASHED ##\n\n" + (lastCommitMessage ?? ""); + const currentRemoteUrl = (await simpleGit(llamaCppDirectory).listRemote(["--get-url", "origin"])).trim(); + + await deleteFilesRecursively(llamaCppDirectory, [".git", ".gitmodules"]); + + await simpleGit(llamaCppDirectory).init(); + await simpleGit(llamaCppDirectory).addConfig("user.name", "node-llama-cpp-ci"); + await simpleGit(llamaCppDirectory).addConfig("user.email", "node-llama-cpp-ci@node-llama-cpp-ci.node-llama-cpp-ci"); + + await simpleGit(llamaCppDirectory).addRemote("origin", currentRemoteUrl); + + await simpleGit(llamaCppDirectory).add([ + "--force", + ...(await getAllFilePaths(llamaCppDirectory, (fileName) => fileName !== ".gitignore")) + ]); + await simpleGit(llamaCppDirectory).commit(newCommitMessage); + + await simpleGit(llamaCppDirectory).add([ + "--force", + ...(await getAllFilePaths(llamaCppDirectory, (fileName) => fileName === ".gitignore")) + ]); + await simpleGit(llamaCppDirectory).commit(newCommitMessage); + + await simpleGit(llamaCppDirectory).branch(["-M", "master"]); + + const newCommitSha = await simpleGit(llamaCppDirectory).raw(["commit-tree", "HEAD^{tree}", "-m", newCommitMessage]); + await simpleGit(llamaCppDirectory).reset(["--hard", newCommitSha.trim()]); + + if (currentBranch != null) + await simpleGit(llamaCppDirectory).tag([currentBranch]); + + await simpleGit(llamaCppDirectory).raw(["gc", "--aggressive", "--prune=all"]); + + await simpleGit(llamaCppDirectory).raw(["bundle", "create", currentReleaseGitBundlePath, "HEAD"]); +} + export async function getGitBundlePathForRelease(githubOwner: string, githubRepo: string, release: string) { const [builtinGithubOwner, builtinGithubRepo] = builtinLlamaCppGitHubRepo.split("/"); if (githubOwner !== builtinGithubOwner || githubRepo !== builtinGithubRepo) @@ -85,3 +140,40 @@ async function getCurrentTagOrBranch() { return null; } + +async function deleteFilesRecursively(folderPath: string, deleteFileOrFolderNames: string[]) { + await Promise.all( + (await fs.readdir(folderPath)) + .map(async (item) => { + const itemPath = path.join(folderPath, item); + + if (deleteFileOrFolderNames.includes(item)) { + // deleting a ".git" folder fails, so we rename it first + const tempNewPath = path.join(folderPath, item + ".deleteme"); + await fs.move(itemPath, tempNewPath); + await fs.remove(tempNewPath); + } else if ((await fs.stat(itemPath)).isDirectory()) + await deleteFilesRecursively(itemPath, deleteFileOrFolderNames); + }) + ); +} + +async function getAllFilePaths(folderPath: string, includePath: (fileName: string) => boolean): Promise { + return ( + await Promise.all( + (await fs.readdir(folderPath)) + .map(async (item) => { + const itemPath = path.join(folderPath, item); + const isDirectory = (await fs.stat(itemPath)).isDirectory(); + + if (isDirectory) + return await getAllFilePaths(itemPath, includePath); + else if (includePath(item)) + return [itemPath]; + + return []; + }) + ) + ) + .flat(); +} From 6f1abcb3126e2f6276cec09f649951e935a141b4 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Mon, 18 Mar 2024 22:55:45 +0200 Subject: [PATCH 07/52] chore: update `vitest` --- package-lock.json | 407 +++++++++++++++++++++++++++++++--------------- package.json | 8 +- 2 files changed, 284 insertions(+), 131 deletions(-) diff --git a/package-lock.json b/package-lock.json index 4789edaa..e326c592 100644 --- a/package-lock.json +++ b/package-lock.json @@ -44,7 +44,7 @@ "@types/cli-progress": "^3.11.0", "@types/cross-spawn": "^6.0.2", "@types/fs-extra": "^11.0.4", - "@types/node": "^20.8.4", + "@types/node": "^20.11.29", "@types/proper-lockfile": "^4.1.4", "@types/semver": "^7.5.8", "@types/uuid": "^9.0.2", @@ -52,8 +52,8 @@ "@types/yargs": "^17.0.24", "@typescript-eslint/eslint-plugin": "^6.3.0", "@typescript-eslint/parser": "^6.3.0", - "@vitest/coverage-v8": "^1.2.2", - "@vitest/ui": "^1.2.2", + "@vitest/coverage-v8": "^1.4.0", + "@vitest/ui": "^1.4.0", "eslint": "^8.46.0", "eslint-plugin-import": "^2.28.0", "eslint-plugin-jsdoc": "^46.9.0", @@ -70,7 +70,7 @@ "typescript": "^5.2.2", "vite-node": "^1.4.0", "vitepress": "1.0.0-rc.22", - "vitest": "^1.2.2", + "vitest": "^1.4.0", "zx": "^7.2.3" }, "engines": { @@ -1561,9 +1561,9 @@ "dev": true }, "node_modules/@jridgewell/trace-mapping": { - "version": "0.3.20", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.20.tgz", - "integrity": "sha512-R8LcPeWZol2zR8mmH3JeKQ6QRCFb7XgUhV9ZlGhHLGyg4wpPiPZNQOOWhFZhxKw8u//yTbNGI42Bx/3paXEQ+Q==", + "version": "0.3.25", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.25.tgz", + "integrity": "sha512-vNk6aEwybGtawWmy/PzwnGDOjCkLWSD2wqvjGGAgOAwCGWySYXfYoxt00IJkTF+8Lb57DwOb3Aa0o9CApepiYQ==", "dev": true, "dependencies": { "@jridgewell/resolve-uri": "^3.1.0", @@ -2001,9 +2001,9 @@ } }, "node_modules/@polka/url": { - "version": "1.0.0-next.24", - "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.24.tgz", - "integrity": "sha512-2LuNTFBIO0m7kKIQvvPHN6UE63VjpmL9rnEEaOOaiSPbZK+zUOYIzBAWcED+3XYzhYsd/0mD57VdxAEqqV52CQ==", + "version": "1.0.0-next.25", + "resolved": "https://registry.npmjs.org/@polka/url/-/url-1.0.0-next.25.tgz", + "integrity": "sha512-j7P6Rgr3mmtdkeDGTe0E/aYyWEWVtc5yFXtHCRHs28/jptDEWfaVOc5T7cblqy1XKPPfCxJc/8DwQ5YgLOZOVQ==", "dev": true }, "node_modules/@rollup/rollup-android-arm-eabi": { @@ -3146,9 +3146,9 @@ "dev": true }, "node_modules/@types/node": { - "version": "20.10.0", - "resolved": "https://registry.npmjs.org/@types/node/-/node-20.10.0.tgz", - "integrity": "sha512-D0WfRmU9TQ8I9PFx9Yc+EBHw+vSpIub4IDvQivcp26PtPrdMGAq5SDcpXEo/epqa/DXotVpekHiLNTg3iaKXBQ==", + "version": "20.11.29", + "resolved": "https://registry.npmjs.org/@types/node/-/node-20.11.29.tgz", + "integrity": "sha512-P99thMkD/1YkCvAtOd6/zGedKNA0p2fj4ZpjCzcNiSCBWgm3cNRTBfa/qjFnsKkkojxu4vVLtWpesnZ9+ap+gA==", "dependencies": { "undici-types": "~5.26.4" } @@ -3415,9 +3415,9 @@ "dev": true }, "node_modules/@vitest/coverage-v8": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/@vitest/coverage-v8/-/coverage-v8-1.2.2.tgz", - "integrity": "sha512-IHyKnDz18SFclIEEAHb9Y4Uxx0sPKC2VO1kdDCs1BF6Ip4S8rQprs971zIsooLUn7Afs71GRxWMWpkCGZpRMhw==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@vitest/coverage-v8/-/coverage-v8-1.4.0.tgz", + "integrity": "sha512-4hDGyH1SvKpgZnIByr9LhGgCEuF9DKM34IBLCC/fVfy24Z3+PZ+Ii9hsVBsHvY1umM1aGPEjceRkzxCfcQ10wg==", "dev": true, "dependencies": { "@ampproject/remapping": "^2.2.1", @@ -3425,12 +3425,13 @@ "debug": "^4.3.4", "istanbul-lib-coverage": "^3.2.2", "istanbul-lib-report": "^3.0.1", - "istanbul-lib-source-maps": "^4.0.1", + "istanbul-lib-source-maps": "^5.0.4", "istanbul-reports": "^3.1.6", "magic-string": "^0.30.5", "magicast": "^0.3.3", "picocolors": "^1.0.0", "std-env": "^3.5.0", + "strip-literal": "^2.0.0", "test-exclude": "^6.0.0", "v8-to-istanbul": "^9.2.0" }, @@ -3438,17 +3439,17 @@ "url": "https://opencollective.com/vitest" }, "peerDependencies": { - "vitest": "^1.0.0" + "vitest": "1.4.0" } }, "node_modules/@vitest/expect": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-1.2.2.tgz", - "integrity": "sha512-3jpcdPAD7LwHUUiT2pZTj2U82I2Tcgg2oVPvKxhn6mDI2On6tfvPQTjAI4628GUGDZrCm4Zna9iQHm5cEexOAg==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-1.4.0.tgz", + "integrity": "sha512-Jths0sWCJZ8BxjKe+p+eKsoqev1/T8lYcrjavEaz8auEJ4jAVY0GwW3JKmdVU4mmNPLPHixh4GNXP7GFtAiDHA==", "dev": true, "dependencies": { - "@vitest/spy": "1.2.2", - "@vitest/utils": "1.2.2", + "@vitest/spy": "1.4.0", + "@vitest/utils": "1.4.0", "chai": "^4.3.10" }, "funding": { @@ -3456,12 +3457,12 @@ } }, "node_modules/@vitest/runner": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-1.2.2.tgz", - "integrity": "sha512-JctG7QZ4LSDXr5CsUweFgcpEvrcxOV1Gft7uHrvkQ+fsAVylmWQvnaAr/HDp3LAH1fztGMQZugIheTWjaGzYIg==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@vitest/runner/-/runner-1.4.0.tgz", + "integrity": "sha512-EDYVSmesqlQ4RD2VvWo3hQgTJ7ZrFQ2VSJdfiJiArkCerDAGeyF1i6dHkmySqk573jLp6d/cfqCN+7wUB5tLgg==", "dev": true, "dependencies": { - "@vitest/utils": "1.2.2", + "@vitest/utils": "1.4.0", "p-limit": "^5.0.0", "pathe": "^1.1.1" }, @@ -3497,9 +3498,9 @@ } }, "node_modules/@vitest/snapshot": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-1.2.2.tgz", - "integrity": "sha512-SmGY4saEw1+bwE1th6S/cZmPxz/Q4JWsl7LvbQIky2tKE35US4gd0Mjzqfr84/4OD0tikGWaWdMja/nWL5NIPA==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@vitest/snapshot/-/snapshot-1.4.0.tgz", + "integrity": "sha512-saAFnt5pPIA5qDGxOHxJ/XxhMFKkUSBJmVt5VgDsAqPTX6JP326r5C/c9UuCMPoXNzuudTPsYDZCoJ5ilpqG2A==", "dev": true, "dependencies": { "magic-string": "^0.30.5", @@ -3511,9 +3512,9 @@ } }, "node_modules/@vitest/spy": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-1.2.2.tgz", - "integrity": "sha512-k9Gcahssw8d7X3pSLq3e3XEu/0L78mUkCjivUqCQeXJm9clfXR/Td8+AP+VC1O6fKPIDLcHDTAmBOINVuv6+7g==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@vitest/spy/-/spy-1.4.0.tgz", + "integrity": "sha512-Ywau/Qs1DzM/8Uc+yA77CwSegizMlcgTJuYGAi0jujOteJOUf1ujunHThYo243KG9nAyWT3L9ifPYZ5+As/+6Q==", "dev": true, "dependencies": { "tinyspy": "^2.2.0" @@ -3523,12 +3524,12 @@ } }, "node_modules/@vitest/ui": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/@vitest/ui/-/ui-1.2.2.tgz", - "integrity": "sha512-CG+5fa8lyoBr+9i+UZGS31Qw81v33QlD10uecHxN2CLJVN+jLnqx4pGzGvFFeJ7jSnUCT0AlbmVWY6fU6NJZmw==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@vitest/ui/-/ui-1.4.0.tgz", + "integrity": "sha512-XC6CMhN1gzYcGbpn6/Oanj4Au2EXwQEX6vpcOeLlZv8dy7g11Ukx8zwtYQbwxs9duK2s9j2o5rbQiCP5DPAcmw==", "dev": true, "dependencies": { - "@vitest/utils": "1.2.2", + "@vitest/utils": "1.4.0", "fast-glob": "^3.3.2", "fflate": "^0.8.1", "flatted": "^3.2.9", @@ -3540,13 +3541,13 @@ "url": "https://opencollective.com/vitest" }, "peerDependencies": { - "vitest": "^1.0.0" + "vitest": "1.4.0" } }, "node_modules/@vitest/utils": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-1.2.2.tgz", - "integrity": "sha512-WKITBHLsBHlpjnDQahr+XK6RE7MiAsgrIkr0pGhQ9ygoxBfUeG0lUG5iLlzqjmKSlBv3+j5EGsriBzh+C3Tq9g==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/@vitest/utils/-/utils-1.4.0.tgz", + "integrity": "sha512-mx3Yd1/6e2Vt/PUC98DcqTirtfxUyAZ32uK82r8rZzbtBeBo+nqgnjx/LvqQdWsrvNtm14VmurNgcf4nqY5gJg==", "dev": true, "dependencies": { "diff-sequences": "^29.6.3", @@ -6138,9 +6139,9 @@ } }, "node_modules/fflate": { - "version": "0.8.1", - "resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.1.tgz", - "integrity": "sha512-/exOvEuc+/iaUm105QIiOt4LpBdMTWsXxqR0HDF35vx3fmaKzw7354gTilCh5rkzEt8WYyG//ku3h3nRmd7CHQ==", + "version": "0.8.2", + "resolved": "https://registry.npmjs.org/fflate/-/fflate-0.8.2.tgz", + "integrity": "sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==", "dev": true }, "node_modules/figures": { @@ -7653,14 +7654,14 @@ } }, "node_modules/istanbul-lib-source-maps": { - "version": "4.0.1", - "resolved": "https://registry.npmjs.org/istanbul-lib-source-maps/-/istanbul-lib-source-maps-4.0.1.tgz", - "integrity": "sha512-n3s8EwkdFIJCG3BPKBYvskgXGoy88ARzvegkitk60NxRdwltLOTaH7CUiMRXvwYorl0Q712iEjcWB+fK/MrWVw==", + "version": "5.0.4", + "resolved": "https://registry.npmjs.org/istanbul-lib-source-maps/-/istanbul-lib-source-maps-5.0.4.tgz", + "integrity": "sha512-wHOoEsNJTVltaJp8eVkm8w+GVkVNHT2YDYo53YdzQEL2gWm1hBX5cGFR9hQJtuGLebidVX7et3+dmDZrmclduw==", "dev": true, "dependencies": { + "@jridgewell/trace-mapping": "^0.3.23", "debug": "^4.1.1", - "istanbul-lib-coverage": "^3.0.0", - "source-map": "^0.6.1" + "istanbul-lib-coverage": "^3.0.0" }, "engines": { "node": ">=10" @@ -8450,9 +8451,9 @@ } }, "node_modules/mlly": { - "version": "1.5.0", - "resolved": "https://registry.npmjs.org/mlly/-/mlly-1.5.0.tgz", - "integrity": "sha512-NPVQvAY1xr1QoVeG0cy8yUYC7FQcOx6evl/RjT1wL5FvzPnzOysoqB/jmx/DhssT2dYa8nxECLAaFI/+gVLhDQ==", + "version": "1.6.1", + "resolved": "https://registry.npmjs.org/mlly/-/mlly-1.6.1.tgz", + "integrity": "sha512-vLgaHvaeunuOXHSmEbZ9izxPx3USsk8KCQ8iC+aTlp5sKRSoZvwhHh5L9VbKSaVC6sJDqbyohIS76E2VmHIPAA==", "dev": true, "dependencies": { "acorn": "^8.11.3", @@ -12191,9 +12192,9 @@ } }, "node_modules/postcss": { - "version": "8.4.33", - "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.33.tgz", - "integrity": "sha512-Kkpbhhdjw2qQs2O2DGX+8m5OVqEcbB9HRBvuYM9pgrjEFUg30A9LmXNlTAUj4S9kgtGyrMbTzVjH7E+s5Re2yg==", + "version": "8.4.36", + "resolved": "https://registry.npmjs.org/postcss/-/postcss-8.4.36.tgz", + "integrity": "sha512-/n7eumA6ZjFHAsbX30yhHup/IMkOmlmvtEi7P+6RMYf+bGJSUHc3geH4a0NSZxAz/RJfiS9tooCTs9LAVYUZKw==", "dev": true, "funding": [ { @@ -12212,7 +12213,7 @@ "dependencies": { "nanoid": "^3.3.7", "picocolors": "^1.0.0", - "source-map-js": "^1.0.2" + "source-map-js": "^1.1.0" }, "engines": { "node": "^10 || ^12 || >=14" @@ -13458,9 +13459,9 @@ } }, "node_modules/source-map-js": { - "version": "1.0.2", - "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.0.2.tgz", - "integrity": "sha512-R0XvVJ9WusLiqTCEiGCmICCMplcCkIwwR11mOSD9CR5u+IXYdiseeEuXCVAjS54zqwkLcPNnmU4OeJ6tUrWhDw==", + "version": "1.1.0", + "resolved": "https://registry.npmjs.org/source-map-js/-/source-map-js-1.1.0.tgz", + "integrity": "sha512-9vC2SfsJzlej6MAaMPLu8HiBSHGdRAJ9hVFYN1ibZoNkeanmDmLUcIrj6G9DGL7XMJ54AKg/G75akXl1/izTOw==", "dev": true, "engines": { "node": ">=0.10.0" @@ -13853,17 +13854,23 @@ } }, "node_modules/strip-literal": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/strip-literal/-/strip-literal-1.3.0.tgz", - "integrity": "sha512-PugKzOsyXpArk0yWmUwqOZecSO0GH0bPoctLcqNDH9J04pVW3lflYE0ujElBGTloevcxF5MofAOZ7C5l2b+wLg==", + "version": "2.0.0", + "resolved": "https://registry.npmjs.org/strip-literal/-/strip-literal-2.0.0.tgz", + "integrity": "sha512-f9vHgsCWBq2ugHAkGMiiYY+AYG0D/cbloKKg0nhaaaSNsujdGIpVXCNsrJpCKr5M0f4aI31mr13UjY6GAuXCKA==", "dev": true, "dependencies": { - "acorn": "^8.10.0" + "js-tokens": "^8.0.2" }, "funding": { "url": "https://github.com/sponsors/antfu" } }, + "node_modules/strip-literal/node_modules/js-tokens": { + "version": "8.0.3", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-8.0.3.tgz", + "integrity": "sha512-UfJMcSJc+SEXEl9lH/VLHSZbThQyLpw1vLO1Lb+j4RWDvG3N2f7yj3PVQA3cmkTBNldJ9eFnM+xEXxHIXrYiJw==", + "dev": true + }, "node_modules/supports-color": { "version": "7.2.0", "resolved": "https://registry.npmjs.org/supports-color/-/supports-color-7.2.0.tgz", @@ -14062,9 +14069,9 @@ } }, "node_modules/tinyspy": { - "version": "2.2.0", - "resolved": "https://registry.npmjs.org/tinyspy/-/tinyspy-2.2.0.tgz", - "integrity": "sha512-d2eda04AN/cPOR89F7Xv5bK/jrQEhmcLFe6HFldoeO9AJtps+fqEnh486vnT/8y4bw38pSyxDcTCAq+Ks2aJTg==", + "version": "2.2.1", + "resolved": "https://registry.npmjs.org/tinyspy/-/tinyspy-2.2.1.tgz", + "integrity": "sha512-KYad6Vy5VDWV4GH3fjpseMQ/XU2BhIYP7Vzd0LG44qRWm/Yt2WCOTicFdvmgo6gWaqooMQCawTtILVQJupKu7A==", "dev": true, "engines": { "node": ">=14.0.0" @@ -14387,9 +14394,9 @@ } }, "node_modules/ufo": { - "version": "1.3.2", - "resolved": "https://registry.npmjs.org/ufo/-/ufo-1.3.2.tgz", - "integrity": "sha512-o+ORpgGwaYQXgqGDwd+hkS4PuZ3QnmqMMxRuajK/a38L6fTpcE5GPIfrf+L/KemFzfUpeUQc1rRS1iDBozvnFA==", + "version": "1.5.2", + "resolved": "https://registry.npmjs.org/ufo/-/ufo-1.5.2.tgz", + "integrity": "sha512-eiutMaL0J2MKdhcOM1tUy13pIrYnyR87fEd8STJQFrrAwImwvlXkxlZEjaKah8r2viPohld08lt73QfLG1NxMg==", "dev": true }, "node_modules/uglify-js": { @@ -15134,18 +15141,17 @@ } }, "node_modules/vitest": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/vitest/-/vitest-1.2.2.tgz", - "integrity": "sha512-d5Ouvrnms3GD9USIK36KG8OZ5bEvKEkITFtnGv56HFaSlbItJuYr7hv2Lkn903+AvRAgSixiamozUVfORUekjw==", + "version": "1.4.0", + "resolved": "https://registry.npmjs.org/vitest/-/vitest-1.4.0.tgz", + "integrity": "sha512-gujzn0g7fmwf83/WzrDTnncZt2UiXP41mHuFYFrdwaLRVQ6JYQEiME2IfEjU3vcFL3VKa75XhI3lFgn+hfVsQw==", "dev": true, "dependencies": { - "@vitest/expect": "1.2.2", - "@vitest/runner": "1.2.2", - "@vitest/snapshot": "1.2.2", - "@vitest/spy": "1.2.2", - "@vitest/utils": "1.2.2", + "@vitest/expect": "1.4.0", + "@vitest/runner": "1.4.0", + "@vitest/snapshot": "1.4.0", + "@vitest/spy": "1.4.0", + "@vitest/utils": "1.4.0", "acorn-walk": "^8.3.2", - "cac": "^6.7.14", "chai": "^4.3.10", "debug": "^4.3.4", "execa": "^8.0.1", @@ -15154,11 +15160,11 @@ "pathe": "^1.1.1", "picocolors": "^1.0.0", "std-env": "^3.5.0", - "strip-literal": "^1.3.0", + "strip-literal": "^2.0.0", "tinybench": "^2.5.1", "tinypool": "^0.8.2", "vite": "^5.0.0", - "vite-node": "1.2.2", + "vite-node": "1.4.0", "why-is-node-running": "^2.2.2" }, "bin": { @@ -15173,8 +15179,8 @@ "peerDependencies": { "@edge-runtime/vm": "*", "@types/node": "^18.0.0 || >=20.0.0", - "@vitest/browser": "^1.0.0", - "@vitest/ui": "^1.0.0", + "@vitest/browser": "1.4.0", + "@vitest/ui": "1.4.0", "happy-dom": "*", "jsdom": "*" }, @@ -15551,6 +15557,175 @@ "node": ">=12" } }, + "node_modules/vitest/node_modules/@rollup/rollup-android-arm-eabi": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm-eabi/-/rollup-android-arm-eabi-4.13.0.tgz", + "integrity": "sha512-5ZYPOuaAqEH/W3gYsRkxQATBW3Ii1MfaT4EQstTnLKViLi2gLSQmlmtTpGucNP3sXEpOiI5tdGhjdE111ekyEg==", + "cpu": [ + "arm" + ], + "dev": true, + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-android-arm64": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-android-arm64/-/rollup-android-arm64-4.13.0.tgz", + "integrity": "sha512-BSbaCmn8ZadK3UAQdlauSvtaJjhlDEjS5hEVVIN3A4bbl3X+otyf/kOJV08bYiRxfejP3DXFzO2jz3G20107+Q==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "android" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-darwin-arm64": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-arm64/-/rollup-darwin-arm64-4.13.0.tgz", + "integrity": "sha512-Ovf2evVaP6sW5Ut0GHyUSOqA6tVKfrTHddtmxGQc1CTQa1Cw3/KMCDEEICZBbyppcwnhMwcDce9ZRxdWRpVd6g==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-darwin-x64": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-darwin-x64/-/rollup-darwin-x64-4.13.0.tgz", + "integrity": "sha512-U+Jcxm89UTK592vZ2J9st9ajRv/hrwHdnvyuJpa5A2ngGSVHypigidkQJP+YiGL6JODiUeMzkqQzbCG3At81Gg==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "darwin" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-linux-arm-gnueabihf": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm-gnueabihf/-/rollup-linux-arm-gnueabihf-4.13.0.tgz", + "integrity": "sha512-8wZidaUJUTIR5T4vRS22VkSMOVooG0F4N+JSwQXWSRiC6yfEsFMLTYRFHvby5mFFuExHa/yAp9juSphQQJAijQ==", + "cpu": [ + "arm" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-linux-arm64-gnu": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-gnu/-/rollup-linux-arm64-gnu-4.13.0.tgz", + "integrity": "sha512-Iu0Kno1vrD7zHQDxOmvweqLkAzjxEVqNhUIXBsZ8hu8Oak7/5VTPrxOEZXYC1nmrBVJp0ZcL2E7lSuuOVaE3+w==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-linux-arm64-musl": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-arm64-musl/-/rollup-linux-arm64-musl-4.13.0.tgz", + "integrity": "sha512-C31QrW47llgVyrRjIwiOwsHFcaIwmkKi3PCroQY5aVq4H0A5v/vVVAtFsI1nfBngtoRpeREvZOkIhmRwUKkAdw==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-linux-riscv64-gnu": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-riscv64-gnu/-/rollup-linux-riscv64-gnu-4.13.0.tgz", + "integrity": "sha512-Oq90dtMHvthFOPMl7pt7KmxzX7E71AfyIhh+cPhLY9oko97Zf2C9tt/XJD4RgxhaGeAraAXDtqxvKE1y/j35lA==", + "cpu": [ + "riscv64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-linux-x64-gnu": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-gnu/-/rollup-linux-x64-gnu-4.13.0.tgz", + "integrity": "sha512-yUD/8wMffnTKuiIsl6xU+4IA8UNhQ/f1sAnQebmE/lyQ8abjsVyDkyRkWop0kdMhKMprpNIhPmYlCxgHrPoXoA==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-linux-x64-musl": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-linux-x64-musl/-/rollup-linux-x64-musl-4.13.0.tgz", + "integrity": "sha512-9RyNqoFNdF0vu/qqX63fKotBh43fJQeYC98hCaf89DYQpv+xu0D8QFSOS0biA7cGuqJFOc1bJ+m2rhhsKcw1hw==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "linux" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-win32-arm64-msvc": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-arm64-msvc/-/rollup-win32-arm64-msvc-4.13.0.tgz", + "integrity": "sha512-46ue8ymtm/5PUU6pCvjlic0z82qWkxv54GTJZgHrQUuZnVH+tvvSP0LsozIDsCBFO4VjJ13N68wqrKSeScUKdA==", + "cpu": [ + "arm64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-win32-ia32-msvc": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-ia32-msvc/-/rollup-win32-ia32-msvc-4.13.0.tgz", + "integrity": "sha512-P5/MqLdLSlqxbeuJ3YDeX37srC8mCflSyTrUsgbU1c/U9j6l2g2GiIdYaGD9QjdMQPMSgYm7hgg0551wHyIluw==", + "cpu": [ + "ia32" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ] + }, + "node_modules/vitest/node_modules/@rollup/rollup-win32-x64-msvc": { + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/@rollup/rollup-win32-x64-msvc/-/rollup-win32-x64-msvc-4.13.0.tgz", + "integrity": "sha512-UKXUQNbO3DOhzLRwHSpa0HnhhCgNODvfoPWv2FCXme8N/ANFfhIPMGuOT+QuKd16+B5yxZ0HdpNlqPvTMS1qfw==", + "cpu": [ + "x64" + ], + "dev": true, + "optional": true, + "os": [ + "win32" + ] + }, "node_modules/vitest/node_modules/esbuild": { "version": "0.19.12", "resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.19.12.tgz", @@ -15658,9 +15833,9 @@ } }, "node_modules/vitest/node_modules/npm-run-path": { - "version": "5.2.0", - "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-5.2.0.tgz", - "integrity": "sha512-W4/tgAXFqFA0iL7fk0+uQ3g7wkL8xJmx3XdK0VGb4cHW//eZTtKGvFBBoRKVTpY7n6ze4NL9ly7rgXcHufqXKg==", + "version": "5.3.0", + "resolved": "https://registry.npmjs.org/npm-run-path/-/npm-run-path-5.3.0.tgz", + "integrity": "sha512-ppwTtiJZq0O/ai0z7yfudtBpWIoxM8yE6nHi1X47eFR2EWORqfbu6CnPlNsjeN683eT0qG6H/Pyf9fCcvjnnnQ==", "dev": true, "dependencies": { "path-key": "^4.0.0" @@ -15700,9 +15875,9 @@ } }, "node_modules/vitest/node_modules/rollup": { - "version": "4.9.6", - "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.9.6.tgz", - "integrity": "sha512-05lzkCS2uASX0CiLFybYfVkwNbKZG5NFQ6Go0VWyogFTXXbR039UVsegViTntkk4OglHBdF54ccApXRRuXRbsg==", + "version": "4.13.0", + "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.13.0.tgz", + "integrity": "sha512-3YegKemjoQnYKmsBlOHfMLVPPA5xLkQ8MHLLSw/fBrFaVkEayL51DilPpNNLq1exr98F2B1TzrV0FUlN3gWRPg==", "dev": true, "dependencies": { "@types/estree": "1.0.5" @@ -15715,19 +15890,19 @@ "npm": ">=8.0.0" }, "optionalDependencies": { - "@rollup/rollup-android-arm-eabi": "4.9.6", - "@rollup/rollup-android-arm64": "4.9.6", - "@rollup/rollup-darwin-arm64": "4.9.6", - "@rollup/rollup-darwin-x64": "4.9.6", - "@rollup/rollup-linux-arm-gnueabihf": "4.9.6", - "@rollup/rollup-linux-arm64-gnu": "4.9.6", - "@rollup/rollup-linux-arm64-musl": "4.9.6", - "@rollup/rollup-linux-riscv64-gnu": "4.9.6", - "@rollup/rollup-linux-x64-gnu": "4.9.6", - "@rollup/rollup-linux-x64-musl": "4.9.6", - "@rollup/rollup-win32-arm64-msvc": "4.9.6", - "@rollup/rollup-win32-ia32-msvc": "4.9.6", - "@rollup/rollup-win32-x64-msvc": "4.9.6", + "@rollup/rollup-android-arm-eabi": "4.13.0", + "@rollup/rollup-android-arm64": "4.13.0", + "@rollup/rollup-darwin-arm64": "4.13.0", + "@rollup/rollup-darwin-x64": "4.13.0", + "@rollup/rollup-linux-arm-gnueabihf": "4.13.0", + "@rollup/rollup-linux-arm64-gnu": "4.13.0", + "@rollup/rollup-linux-arm64-musl": "4.13.0", + "@rollup/rollup-linux-riscv64-gnu": "4.13.0", + "@rollup/rollup-linux-x64-gnu": "4.13.0", + "@rollup/rollup-linux-x64-musl": "4.13.0", + "@rollup/rollup-win32-arm64-msvc": "4.13.0", + "@rollup/rollup-win32-ia32-msvc": "4.13.0", + "@rollup/rollup-win32-x64-msvc": "4.13.0", "fsevents": "~2.3.2" } }, @@ -15756,13 +15931,13 @@ } }, "node_modules/vitest/node_modules/vite": { - "version": "5.0.12", - "resolved": "https://registry.npmjs.org/vite/-/vite-5.0.12.tgz", - "integrity": "sha512-4hsnEkG3q0N4Tzf1+t6NdN9dg/L3BM+q8SWgbSPnJvrgH2kgdyzfVJwbR1ic69/4uMJJ/3dqDZZE5/WwqW8U1w==", + "version": "5.1.6", + "resolved": "https://registry.npmjs.org/vite/-/vite-5.1.6.tgz", + "integrity": "sha512-yYIAZs9nVfRJ/AiOLCA91zzhjsHUgMjB+EigzFb6W2XTLO8JixBCKCjvhKZaye+NKYHCrkv3Oh50dH9EdLU2RA==", "dev": true, "dependencies": { "esbuild": "^0.19.3", - "postcss": "^8.4.32", + "postcss": "^8.4.35", "rollup": "^4.2.0" }, "bin": { @@ -15810,28 +15985,6 @@ } } }, - "node_modules/vitest/node_modules/vite-node": { - "version": "1.2.2", - "resolved": "https://registry.npmjs.org/vite-node/-/vite-node-1.2.2.tgz", - "integrity": "sha512-1as4rDTgVWJO3n1uHmUYqq7nsFgINQ9u+mRcXpjeOMJUmviqNKjcZB7UfRZrlM7MjYXMKpuWp5oGkjaFLnjawg==", - "dev": true, - "dependencies": { - "cac": "^6.7.14", - "debug": "^4.3.4", - "pathe": "^1.1.1", - "picocolors": "^1.0.0", - "vite": "^5.0.0" - }, - "bin": { - "vite-node": "vite-node.mjs" - }, - "engines": { - "node": "^18.0.0 || >=20.0.0" - }, - "funding": { - "url": "https://opencollective.com/vitest" - } - }, "node_modules/vscode-oniguruma": { "version": "1.7.0", "resolved": "https://registry.npmjs.org/vscode-oniguruma/-/vscode-oniguruma-1.7.0.tgz", diff --git a/package.json b/package.json index dd9fee4d..a2a075e8 100644 --- a/package.json +++ b/package.json @@ -119,7 +119,7 @@ "@types/cli-progress": "^3.11.0", "@types/cross-spawn": "^6.0.2", "@types/fs-extra": "^11.0.4", - "@types/node": "^20.8.4", + "@types/node": "^20.11.29", "@types/proper-lockfile": "^4.1.4", "@types/semver": "^7.5.8", "@types/uuid": "^9.0.2", @@ -127,8 +127,8 @@ "@types/yargs": "^17.0.24", "@typescript-eslint/eslint-plugin": "^6.3.0", "@typescript-eslint/parser": "^6.3.0", - "@vitest/coverage-v8": "^1.2.2", - "@vitest/ui": "^1.2.2", + "@vitest/coverage-v8": "^1.4.0", + "@vitest/ui": "^1.4.0", "eslint": "^8.46.0", "eslint-plugin-import": "^2.28.0", "eslint-plugin-jsdoc": "^46.9.0", @@ -145,7 +145,7 @@ "typescript": "^5.2.2", "vite-node": "^1.4.0", "vitepress": "1.0.0-rc.22", - "vitest": "^1.2.2", + "vitest": "^1.4.0", "zx": "^7.2.3" }, "dependencies": { From 3995ae1b4686aea6dee9ff18ca2a374270c9cdd3 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 19 Mar 2024 00:59:44 +0200 Subject: [PATCH 08/52] refactor: `GgufParser` --- src/gguf/GGUFInsights.ts | 2 +- src/gguf/GGUFMetadata.ts | 31 ++-- ...MagicError.ts => InvalidGgufMagicError.ts} | 2 +- src/gguf/ggufParser/GGUFParser.ts | 162 ----------------- .../{GGUFTypes.ts => GgufMetadataTypes.ts} | 170 ++++++++++++------ src/gguf/ggufParser/GgufParser.ts | 159 ++++++++++++++++ src/gguf/ggufParser/checkArchitecture.ts | 93 ---------- src/gguf/ggufParser/stream/GGUFBaseStream.ts | 102 ----------- src/gguf/ggufParser/stream/GGUFFetchStream.ts | 47 ----- src/gguf/ggufParser/stream/GGUFReadStream.ts | 50 ------ src/gguf/ggufParser/stream/GgufBaseStream.ts | 106 +++++++++++ src/gguf/ggufParser/stream/GgufFetchStream.ts | 62 +++++++ .../ggufParser/stream/GgufFsReadStream.ts | 64 +++++++ src/gguf/ggufParser/utils/GgufReadOffset.ts | 21 +++ .../utils/parseGgufFileTypeNumber.ts | 34 ++++ test/modelDependent/functionary/gguf.test.ts | 16 +- test/standalone/gguf/gguf.test.ts | 12 +- 17 files changed, 595 insertions(+), 538 deletions(-) rename src/gguf/errors/{InvalidGGUFMagicError.ts => InvalidGgufMagicError.ts} (60%) delete mode 100644 src/gguf/ggufParser/GGUFParser.ts rename src/gguf/ggufParser/{GGUFTypes.ts => GgufMetadataTypes.ts} (60%) create mode 100644 src/gguf/ggufParser/GgufParser.ts delete mode 100644 src/gguf/ggufParser/checkArchitecture.ts delete mode 100644 src/gguf/ggufParser/stream/GGUFBaseStream.ts delete mode 100644 src/gguf/ggufParser/stream/GGUFFetchStream.ts delete mode 100644 src/gguf/ggufParser/stream/GGUFReadStream.ts create mode 100644 src/gguf/ggufParser/stream/GgufBaseStream.ts create mode 100644 src/gguf/ggufParser/stream/GgufFetchStream.ts create mode 100644 src/gguf/ggufParser/stream/GgufFsReadStream.ts create mode 100644 src/gguf/ggufParser/utils/GgufReadOffset.ts create mode 100644 src/gguf/ggufParser/utils/parseGgufFileTypeNumber.ts diff --git a/src/gguf/GGUFInsights.ts b/src/gguf/GGUFInsights.ts index 409720d1..b8a33467 100644 --- a/src/gguf/GGUFInsights.ts +++ b/src/gguf/GGUFInsights.ts @@ -1,6 +1,6 @@ import {Llama} from "../bindings/Llama.js"; import MissingNodeLlamaError from "./errors/MissingNodeLlamaError.js"; -import {GGUFMetadataResponse} from "./ggufParser/GGUFParser.js"; +import {GGUFMetadataResponse} from "./ggufParser/GgufParser.js"; import NotEnoughVRamError from "./errors/ModelScore/NotEnoughVRamError.js"; const PAD_AVAILABLE_VRAM = 1024 ** 2 * 500; // 500MB diff --git a/src/gguf/GGUFMetadata.ts b/src/gguf/GGUFMetadata.ts index 61520bb4..c8d5796d 100644 --- a/src/gguf/GGUFMetadata.ts +++ b/src/gguf/GGUFMetadata.ts @@ -1,9 +1,9 @@ import retry from "async-retry"; import MetadataNotParsedYetError from "./errors/MetadataNotParsedYetError.js"; import GGUFInsights, {GGUFInsightsOptions} from "./GGUFInsights.js"; -import GGUFParser, {GGUFMetadataResponse} from "./ggufParser/GGUFParser.js"; -import GGUFFetchStream from "./ggufParser/stream/GGUFFetchStream.js"; -import GGUFReadStream from "./ggufParser/stream/GGUFReadStream.js"; +import {GgufParser, GGUFMetadataResponse} from "./ggufParser/GgufParser.js"; +import {GgufFetchStream} from "./ggufParser/stream/GgufFetchStream.js"; +import {GgufFsReadStream} from "./ggufParser/stream/GgufFsReadStream.js"; export type GGUFMetadataOptions = { source?: "network" | "local", @@ -17,6 +17,11 @@ export default class GGUFMetadata { public readonly path: string; public readonly options: Partial = {}; + public constructor(path: string, options: Partial = {}) { + this.options = options; + this.path = path; + } + public get metadata() { if (!this._metadata) { throw new MetadataNotParsedYetError(this.path); @@ -28,24 +33,28 @@ export default class GGUFMetadata { return new GGUFInsights(this.metadata, this.options.insights); } - public constructor(path: string, options: Partial = {}) { - this.options = options; - this.path = path; - } - public async parse() { const stream = this._createStream(); - const parser = new GGUFParser(stream, this.options.ignoreKeys); + const parser = new GgufParser({ + stream, + ignoreKeys: this.options.ignoreKeys + }); return this._metadata = await parser.parseMetadata(); } private _createStream() { switch (this.options.source) { case "network": - return new GGUFFetchStream(this.path, {retry: this.options.retry}); + return new GgufFetchStream({ + url: this.path, + retryOptions: this.options.retry + }); case "local": default: - return new GGUFReadStream(this.path, {retry: this.options.retry}); + return new GgufFsReadStream({ + filePath: this.path, + retryOptions: this.options.retry + }); } } } diff --git a/src/gguf/errors/InvalidGGUFMagicError.ts b/src/gguf/errors/InvalidGgufMagicError.ts similarity index 60% rename from src/gguf/errors/InvalidGGUFMagicError.ts rename to src/gguf/errors/InvalidGgufMagicError.ts index f3ec7f5d..0efe3ef6 100644 --- a/src/gguf/errors/InvalidGGUFMagicError.ts +++ b/src/gguf/errors/InvalidGgufMagicError.ts @@ -1,4 +1,4 @@ -export default class InvalidGGUFMagicError extends Error { +export class InvalidGgufMagicError extends Error { public constructor(message = "Invalid GGUF magic") { super(message); } diff --git a/src/gguf/ggufParser/GGUFParser.ts b/src/gguf/ggufParser/GGUFParser.ts deleted file mode 100644 index aff71de1..00000000 --- a/src/gguf/ggufParser/GGUFParser.ts +++ /dev/null @@ -1,162 +0,0 @@ -import InvalidGGUFMagicError from "../errors/InvalidGGUFMagicError.js"; -import UnsupportedMetadataTypeError from "../errors/UnsupportedMetadataTypeError.js"; -import {fileTypeIntToString} from "./checkArchitecture.js"; -import {GGUFMetadataAny} from "./GGUFTypes.js"; -import GGUFBaseStream, {METHOD_TO_BYTE_COUNT} from "./stream/GGUFBaseStream.js"; - -const METADATA_VALUE_TO_METHOD: { [key: number]: keyof typeof METHOD_TO_BYTE_COUNT } = { - 0: "readUint8", - 1: "readInt8", - 2: "readUint16", - 3: "readInt16", - 4: "readUint32", - 5: "readInt32", - 6: "readFloat32", - 7: "readBool" -}; - -const METADATA_STRING = 8; -const METADATA_ARRAY = 9; - -const GGUF_MAGIC = "GGUF"; - -const DEFAULT_IGNORE_METADATA_KEYS = ["tokenizer.ggml.tokens", "tokenizer.ggml.scores", "tokenizer.ggml.token_type", "tokenizer.ggml.merges"]; - -export type GGUFMetadataResponse = { - metadataSize: number, - metadata: GGUFMetadataAny -}; - -export default class GGUFParser { - protected readonly _stream: GGUFBaseStream; - - public ignoreKeys = DEFAULT_IGNORE_METADATA_KEYS; - - public constructor(_stream: GGUFBaseStream, ignoreKeys = DEFAULT_IGNORE_METADATA_KEYS) { - this.ignoreKeys = ignoreKeys; - this._stream = _stream; - } - - private async _readMetadataValue(type: keyof typeof METADATA_VALUE_TO_METHOD | 8 | 9, offset: number): Promise<{ - value: any, - newOffset: number - }> { - const numberMethod = METADATA_VALUE_TO_METHOD[type]; - if (numberMethod) { - return { // Numbers - value: await this._stream[numberMethod](offset), - newOffset: offset + METHOD_TO_BYTE_COUNT[numberMethod] - }; - } - - if (METADATA_STRING === type) { - const {string, newOffset} = await this._stream.readString(offset); - return {value: string, newOffset}; - } - - if (METADATA_ARRAY === type) { - const arrayType = await this._stream.readUint32(offset); - offset += METHOD_TO_BYTE_COUNT.readUint32; - - const arrayLength = await this._stream.readUint64(offset); - offset += METHOD_TO_BYTE_COUNT.readUint64; - - const arrayValues = []; - for (let i = 0; i < arrayLength; i++) { - const {value, newOffset} = await this._readMetadataValue(arrayType, offset); - arrayValues.push(value); - offset = newOffset; - } - return {value: arrayValues, newOffset: offset}; - } - - throw new UnsupportedMetadataTypeError(type); - } - - private async _parseMetadataRaw(): Promise<{metadata: { [key: string]: any }, metadataSize: number}> { - let offset = 0; - - const readMagicBytesLength = METHOD_TO_BYTE_COUNT.readUint8 * GGUF_MAGIC.length; - const magicBytes = await this._stream.readNBytes(readMagicBytesLength); - offset += readMagicBytesLength; - - const magicText = String.fromCharCode(...magicBytes); - if (magicText !== GGUF_MAGIC) { - throw new InvalidGGUFMagicError(); - } - - const version = await this._stream.readUint32(offset); - offset += METHOD_TO_BYTE_COUNT.readUint32; - - const tensorCount = await this._stream.readUint64(offset); - offset += METHOD_TO_BYTE_COUNT.readUint64; - - const metadataKVCount = await this._stream.readUint64(offset); - offset += METHOD_TO_BYTE_COUNT.readUint64; - - const metadata: { [key: string]: any } = { - version, - tensorCount: GGUFBaseStream.castNumber(tensorCount as bigint) - }; - - for (let i = 0; i < Number(metadataKVCount); i++) { - // read key - const keyResult = await this._stream.readString(offset); - offset = keyResult.newOffset; - - // read the value type - const valueType = await this._stream.readUint32(offset); - offset += METHOD_TO_BYTE_COUNT.readUint32; - - // read value - const valueResult = await this._readMetadataValue(valueType, offset); - offset = valueResult.newOffset; - metadata[keyResult.string] = valueResult.value; - } - - return { - metadata: metadata, - metadataSize: offset - }; - } - - public async parseMetadata(): Promise { - const metadataRaw = await this._parseMetadataRaw(); - const metadata: { [key: string]: any } = {}; - - for (const [key, value] of Object.entries(metadataRaw.metadata)) { - if (this.ignoreKeys.includes(key)) { - continue; - } - const {lastObject, lastKey} = GGUFParser._getNestedObject(key, metadata); - lastObject[lastKey] = value; - } - - if (typeof metadata?.general?.file_type === "number") { - metadata.general["file_type"] = fileTypeIntToString(metadata.general.file_type) || metadata.general.file_type; - } - - return { - metadata: metadata as GGUFMetadataAny, - metadataSize: metadataRaw.metadataSize - }; - } - - protected static _getNestedObject(key: string, currentNestedObject: any) { - const nestedKey = key.split("."); - const lastKey = nestedKey.pop()!; - - while (nestedKey.length > 0) { - const currentKey = nestedKey.shift()!; - if (!currentNestedObject[currentKey]) { - currentNestedObject[currentKey] = {}; - } - currentNestedObject = currentNestedObject[currentKey]; - } - - return { - lastObject: currentNestedObject, - lastKey - }; - } -} diff --git a/src/gguf/ggufParser/GGUFTypes.ts b/src/gguf/ggufParser/GgufMetadataTypes.ts similarity index 60% rename from src/gguf/ggufParser/GGUFTypes.ts rename to src/gguf/ggufParser/GgufMetadataTypes.ts index 1670f7e3..a4a142b9 100644 --- a/src/gguf/ggufParser/GGUFTypes.ts +++ b/src/gguf/ggufParser/GgufMetadataTypes.ts @@ -1,17 +1,39 @@ -import {fileTypeIntToString} from "./checkArchitecture.js"; - -export type GGUFArchitectureType = - | "llama" - | "falcon" - | "mpt" - | "gptneox" - | "gptj" - | "gpt2" - | "bloom" - | "rwkv" - | "whisper"; - -export type GGUFMetadataArchitectureProperties = { +export const enum GgufArchitectureType { + llama = "llama", + falcon = "falcon", + mpt = "mpt", + gptneox = "gptneox", + gptj = "gptj", + gpt2 = "gpt2", + bloom = "bloom", + rwkv = "rwkv", + whisper = "whisper" +} + +export const enum GgufFileType { + ALL_F32 = "ALL_F32", + MOSTLY_F16 = "MOSTLY_F16", + MOSTLY_Q4_0 = "MOSTLY_Q4_0", + MOSTLY_Q4_1 = "MOSTLY_Q4_1", + MOSTLY_Q4_1_SOME_F16 = "MOSTLY_Q4_1_SOME_F16", + MOSTLY_Q4_2 = "MOSTLY_Q4_2", + MOSTLY_Q4_3 = "MOSTLY_Q4_3", + MOSTLY_Q8_0 = "MOSTLY_Q8_0", + MOSTLY_Q5_0 = "MOSTLY_Q5_0", + MOSTLY_Q5_1 = "MOSTLY_Q5_1", + MOSTLY_Q2_K = "MOSTLY_Q2_K", + MOSTLY_Q3_K_S = "MOSTLY_Q3_K_S", + MOSTLY_Q3_K_M = "MOSTLY_Q3_K_M", + MOSTLY_Q3_K_L = "MOSTLY_Q3_K_L", + MOSTLY_Q4_K_S = "MOSTLY_Q4_K_S", + MOSTLY_Q4_K_M = "MOSTLY_Q4_K_M", + MOSTLY_Q5_K_S = "MOSTLY_Q5_K_S", + MOSTLY_Q5_K_M = "MOSTLY_Q5_K_M", + MOSTLY_Q6_K = "MOSTLY_Q6_K" +} + + +export type GgufMetadataArchitectureProperties = { context_length: number, embedding_length: number, block_count: number, @@ -44,8 +66,9 @@ export type GGUFMetadataArchitectureProperties = { } }; -export type GGUFMetadataGeneralProperties = { - architecture: GGUFArchitectureType, +export type GgufMetadataGeneralProperties = { + architecture: GgufArchitectureType, + /** * The version of the quantization format. Not required if the model is not * quantized (i.e. no tensors are quantized). If any tensors are quantized, @@ -82,6 +105,7 @@ export type GGUFMetadataGeneralProperties = { * covered by the other fields */ description: string, + /** * License of the model, expressed as a SPDX license expression * (e.g. `MIT OR Apache-2.0`). *Should not* include any other information, @@ -109,18 +133,18 @@ export type GGUFMetadataGeneralProperties = { * An enumerated value describing the type of the majority of the tensors * in the file. Optional; can be inferred from the tensor types. */ - file_type: ReturnType + file_type?: GgufFileType | undefined }; -export type GGUFMetadataAny = { - general: GGUFMetadataGeneralProperties +export type GgufMetadataAny = { + general: GgufMetadataGeneralProperties } & { - [key in GGUFArchitectureType]: GGUFMetadataArchitectureProperties + [key in GgufArchitectureType]: GgufMetadataArchitectureProperties }; -export type GGUFMetadataLLAMA = { - general: GGUFMetadataGeneralProperties & { - architecture: "llama" +export type GgufMetadataLLAMA = { + general: GgufMetadataGeneralProperties & { + architecture: GgufArchitectureType.llama }, llama: { @@ -143,9 +167,9 @@ export type GGUFMetadataLLAMA = { } }; -export type GGUFMetadataFalcon = { - general: GGUFMetadataGeneralProperties & { - architecture: "falcon" +export type GgufMetadataFalcon = { + general: GgufMetadataGeneralProperties & { + architecture: GgufArchitectureType.falcon }, falcon: { @@ -162,9 +186,9 @@ export type GGUFMetadataFalcon = { } }; -export type GGUFMetadataMPT = { - general: GGUFMetadataGeneralProperties & { - architecture: "mpt" +export type GgufMetadataMPT = { + general: GgufMetadataGeneralProperties & { + architecture: GgufArchitectureType.mpt }, mpt: { @@ -180,9 +204,9 @@ export type GGUFMetadataMPT = { } }; -export type GGUFMetadataGPTNeoX = { - general: GGUFMetadataGeneralProperties & { - architecture: "gptneox" +export type GgufMetadataGPTNeoX = { + general: GgufMetadataGeneralProperties & { + architecture: GgufArchitectureType.gptneox }, gptneox: { @@ -202,9 +226,9 @@ export type GGUFMetadataGPTNeoX = { } }; -export type GGUFMetadataGPTJ = { - general: GGUFMetadataGeneralProperties & { - architecture: "gptj" +export type GgufMetadataGPTJ = { + general: GgufMetadataGeneralProperties & { + architecture: GgufArchitectureType.gptj }, gptj: { @@ -222,9 +246,9 @@ export type GGUFMetadataGPTJ = { } }; -export type GGUFMetadataGPT2 = { - general: GGUFMetadataGeneralProperties & { - architecture: "gpt2" +export type GgufMetadataGPT2 = { + general: GgufMetadataGeneralProperties & { + architecture: GgufArchitectureType.gpt2 }, gpt2: { @@ -238,9 +262,9 @@ export type GGUFMetadataGPT2 = { } }; -export type GGUFMetadataBloom = { - general: GGUFMetadataGeneralProperties & { - architecture: "bloom" +export type GgufMetadataBloom = { + general: GgufMetadataGeneralProperties & { + architecture: GgufArchitectureType.bloom }, bloom: { @@ -255,9 +279,9 @@ export type GGUFMetadataBloom = { } }; -export type GGUFMetadataRWKV = { - general: GGUFMetadataGeneralProperties & { - architecture: "rwkv" +export type GgufMetadataRWKV = { + general: GgufMetadataGeneralProperties & { + architecture: GgufArchitectureType.rwkv }, rwkv: { @@ -268,9 +292,9 @@ export type GGUFMetadataRWKV = { } }; -export type GGUFMetadataWhisper = { - general: GGUFMetadataGeneralProperties & { - architecture: "whisper" +export type GgufMetadataWhisper = { + general: GgufMetadataGeneralProperties & { + architecture: GgufArchitectureType.whisper }, whisper: { encoder: { @@ -293,14 +317,46 @@ export type GGUFMetadataWhisper = { } }; +export type GgufMetadata = + | GgufMetadataLLAMA + | GgufMetadataFalcon + | GgufMetadataMPT + | GgufMetadataGPTNeoX + | GgufMetadataGPTJ + | GgufMetadataGPT2 + | GgufMetadataBloom + | GgufMetadataRWKV + | GgufMetadataWhisper; + + +export function isLlamaMetadata(metadata: GgufMetadata): metadata is GgufMetadataLLAMA { + return metadata.general.architecture === GgufArchitectureType.llama; +} + +export function isMPTMetadata(metadata: GgufMetadata): metadata is GgufMetadataMPT { + return metadata.general.architecture === GgufArchitectureType.mpt; +} + +export function isGPTNeoXMetadata(metadata: GgufMetadata): metadata is GgufMetadataGPTNeoX { + return metadata.general.architecture === GgufArchitectureType.gptneox; +} + +export function isGPTJMetadata(metadata: GgufMetadata): metadata is GgufMetadataGPTJ { + return metadata.general.architecture === GgufArchitectureType.gptj; +} + +export function isGPT2Metadata(metadata: GgufMetadata): metadata is GgufMetadataGPT2 { + return metadata.general.architecture === GgufArchitectureType.gpt2; +} + +export function isBloomMetadata(metadata: GgufMetadata): metadata is GgufMetadataBloom { + return metadata.general.architecture === GgufArchitectureType.bloom; +} + +export function isFalconMetadata(metadata: GgufMetadata): metadata is GgufMetadataFalcon { + return metadata.general.architecture === GgufArchitectureType.falcon; +} -export type GGUFMetadata = - | GGUFMetadataLLAMA - | GGUFMetadataFalcon - | GGUFMetadataMPT - | GGUFMetadataGPTNeoX - | GGUFMetadataGPTJ - | GGUFMetadataGPT2 - | GGUFMetadataBloom - | GGUFMetadataRWKV - | GGUFMetadataWhisper; +export function isRWKVMetadata(metadata: GgufMetadata): metadata is GgufMetadataRWKV { + return metadata.general.architecture === GgufArchitectureType.rwkv; +} diff --git a/src/gguf/ggufParser/GgufParser.ts b/src/gguf/ggufParser/GgufParser.ts new file mode 100644 index 00000000..8bac3baf --- /dev/null +++ b/src/gguf/ggufParser/GgufParser.ts @@ -0,0 +1,159 @@ +import {InvalidGgufMagicError} from "../errors/InvalidGgufMagicError.js"; +import UnsupportedMetadataTypeError from "../errors/UnsupportedMetadataTypeError.js"; +import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; +import {GgufReadOffset} from "./utils/GgufReadOffset.js"; +import {parseGgufFileTypeNumber} from "./utils/parseGgufFileTypeNumber.js"; +import {GgufMetadataAny} from "./GgufMetadataTypes.js"; +import {GgufBaseStream, METHOD_TO_BYTE_COUNT} from "./stream/GgufBaseStream.js"; + +const enum MetadataValueType { + Uint8 = 0, + Int8 = 1, + Uint16 = 2, + Int16 = 3, + Uint32 = 4, + Int32 = 5, + Float32 = 6, + Bool = 7, + String = 8, + Array = 9 +} + +const ggufMagic = "GGUF"; + +const defaultIgnoreMetadataKeys = [ + "tokenizer.ggml.tokens", + "tokenizer.ggml.scores", + "tokenizer.ggml.token_type", + "tokenizer.ggml.merges" +]; + +export type GGUFMetadataResponse = { + metadataSize: number, + metadata: GgufMetadataAny +}; + +export type GgufParserOptions = { + stream: GgufBaseStream, + ignoreKeys?: string[] +}; + +export class GgufParser { + private readonly _stream: GgufBaseStream; + public ignoreKeys = defaultIgnoreMetadataKeys; + + public constructor({stream, ignoreKeys = defaultIgnoreMetadataKeys}: GgufParserOptions) { + this.ignoreKeys = ignoreKeys; + this._stream = stream; + } + + public async parseMetadata({logWarnings = true}: {logWarnings?: boolean} = {}): Promise { + const metadataRaw = await this._parseMetadataRaw(); + const metadata: { [key: string]: any } = {}; + + for (const [key, value] of Object.entries(metadataRaw.metadata)) { + if (this.ignoreKeys.includes(key)) + continue; + + const {lastObject, lastKey} = GgufParser._getNestedObject(key, metadata); + if (Object.hasOwn(lastObject, lastKey) && logWarnings) + console.warn(getConsoleLogPrefix() + `Metadata key "${key}" is already occupied by a value. Overwriting it.`); + + lastObject[lastKey] = value; + } + + if (typeof metadata?.general?.file_type === "number") { + metadata.general["file_type"] = parseGgufFileTypeNumber(metadata.general.file_type) || metadata.general.file_type; + } + + return { + metadata: metadata as GgufMetadataAny, + metadataSize: metadataRaw.metadataSize + }; + } + + private async _readMetadataValue(type: MetadataValueType, offset: number | GgufReadOffset): Promise { + const readOffset = GgufReadOffset.resolveReadOffset(offset); + + switch (type) { + case MetadataValueType.Uint8: return await this._stream.readUint8(readOffset); + case MetadataValueType.Int8: return await this._stream.readInt8(readOffset); + case MetadataValueType.Uint16: return await this._stream.readUint16(readOffset); + case MetadataValueType.Int16: return await this._stream.readInt16(readOffset); + case MetadataValueType.Uint32: return await this._stream.readUint32(readOffset); + case MetadataValueType.Int32: return await this._stream.readInt32(readOffset); + case MetadataValueType.Float32: return await this._stream.readFloat32(readOffset); + case MetadataValueType.Bool: return await this._stream.readBool(readOffset); + case MetadataValueType.String: return await this._stream.readString(readOffset); + } + + if (type === MetadataValueType.Array) { + const arrayType = await this._stream.readUint32(readOffset); + const arrayLength = await this._stream.readUint64(readOffset); + + const arrayValues: any[] = []; + for (let i = 0; i < arrayLength; i++) { + const value = await this._readMetadataValue(arrayType, readOffset); + arrayValues.push(value); + } + return arrayValues; + } + + throw new UnsupportedMetadataTypeError(type); + } + + private async _parseMetadataRaw(): Promise<{metadata: Record, metadataSize: number}> { + const readOffset = new GgufReadOffset(0); + + const magicBytes = await this._stream.readByteRange(readOffset, METHOD_TO_BYTE_COUNT.readUint8 * ggufMagic.length); + const magicText = String.fromCharCode(...magicBytes); + + if (magicText !== ggufMagic) + throw new InvalidGgufMagicError(); + + const version = await this._stream.readUint32(readOffset); + const tensorCount = await this._stream.readUint64(readOffset); + const metadataKVCount = Number(await this._stream.readUint64(readOffset)); + + const metadata: { [key: string]: any } = { + version, + tensorCount: GgufBaseStream.castNumber(tensorCount) + }; + + for (let i = 0; i < metadataKVCount; i++) { + const keyResult = await this._stream.readString(readOffset); + const valueType = await this._stream.readUint32(readOffset); + metadata[keyResult] = await this._readMetadataValue(valueType, readOffset); + } + + return { + metadata: metadata, + metadataSize: readOffset.offset + }; + } + + private static _getNestedObject(key: string, currentNestedObject: any) { + const nestedKey = key.split("."); + const lastKey = nestedKey.pop()!; + + while (nestedKey.length > 0) { + const currentKey = nestedKey.shift()!; + if (!Object.hasOwn(currentNestedObject, currentKey)) + currentNestedObject[currentKey] = {}; + else { + const value = currentNestedObject[currentKey]; + if (value instanceof Array || value == null || typeof value !== "object") + throw new Error( + `Cannot create nested object for key "${key}". The key "${currentKey}" is already occupied by a non-object value.` + ); + } + + currentNestedObject = currentNestedObject[currentKey]; + } + + return { + lastObject: currentNestedObject, + lastKey + }; + } +} diff --git a/src/gguf/ggufParser/checkArchitecture.ts b/src/gguf/ggufParser/checkArchitecture.ts deleted file mode 100644 index 06b56501..00000000 --- a/src/gguf/ggufParser/checkArchitecture.ts +++ /dev/null @@ -1,93 +0,0 @@ -import { - GGUFMetadata, - GGUFMetadataBloom, GGUFMetadataFalcon, - GGUFMetadataGPT2, - GGUFMetadataGPTJ, - GGUFMetadataGPTNeoX, - GGUFMetadataLLAMA, - GGUFMetadataMPT, GGUFMetadataRWKV -} from "./GGUFTypes.js"; - - -export function isLlamaMetadata (metadata: GGUFMetadata): metadata is GGUFMetadataLLAMA { - return metadata.general.architecture === "llama"; -} - -export function isMPTMetadata (metadata: GGUFMetadata): metadata is GGUFMetadataMPT { - return metadata.general.architecture === "mpt"; -} - -export function isGPTNeoXMetadata (metadata: GGUFMetadata): metadata is GGUFMetadataGPTNeoX { - return metadata.general.architecture === "gptneox"; -} - -export function isGPTJMetadata (metadata: GGUFMetadata): metadata is GGUFMetadataGPTJ { - return metadata.general.architecture === "gptj"; -} - -export function isGPT2Metadata (metadata: GGUFMetadata): metadata is GGUFMetadataGPT2 { - return metadata.general.architecture === "gpt2"; -} - -export function isBloomMetadata (metadata: GGUFMetadata): metadata is GGUFMetadataBloom { - return metadata.general.architecture === "bloom"; -} - -export function isFalconMetadata (metadata: GGUFMetadata): metadata is GGUFMetadataFalcon { - return metadata.general.architecture === "falcon"; -} - -export function isRWKVMetadata (metadata: GGUFMetadata): metadata is GGUFMetadataRWKV { - return metadata.general.architecture === "rwkv"; -} - - -/** - * https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#general-metadata - * Convert file type from string to int - */ -export function fileTypeIntToString(fileType?: number) { - if (fileType == null) return; - switch (fileType) { - case 0: - return "ALL_F32"; - case 1: - return "MOSTLY_F16"; - case 2: - return "MOSTLY_Q4_0"; - case 3: - return "MOSTLY_Q4_1"; - case 4: - return "MOSTLY_Q4_1_SOME_F16"; - case 5: - return "MOSTLY_Q4_2"; - case 6: - return "MOSTLY_Q4_3"; - case 7: - return "MOSTLY_Q8_0"; - case 8: - return "MOSTLY_Q5_0"; - case 9: - return "MOSTLY_Q5_1"; - case 10: - return "MOSTLY_Q2_K"; - case 11: - return "MOSTLY_Q3_K_S"; - case 12: - return "MOSTLY_Q3_K_M"; - case 13: - return "MOSTLY_Q3_K_L"; - case 14: - return "MOSTLY_Q4_K_S"; - case 15: - return "MOSTLY_Q4_K_M"; - case 16: - return "MOSTLY_Q5_K_S"; - case 17: - return "MOSTLY_Q5_K_M"; - case 18: - return "MOSTLY_Q6_K"; - } - - return; -} diff --git a/src/gguf/ggufParser/stream/GGUFBaseStream.ts b/src/gguf/ggufParser/stream/GGUFBaseStream.ts deleted file mode 100644 index bb941b5b..00000000 --- a/src/gguf/ggufParser/stream/GGUFBaseStream.ts +++ /dev/null @@ -1,102 +0,0 @@ -import {Buffer} from "buffer"; - -export const METHOD_TO_BYTE_COUNT = { - readUint8: 1, - readUint16: 2, - readUint32: 4, - readUint64: 8, - readInt8: 1, - readInt16: 2, - readInt32: 4, - readInt64: 8, - readFloat32: 4, - readFloat64: 8, - readBool: 1 -}; - -export const ALLOCATION_SIZE = 1024 * 1024 * 1.5; // 1.5MB - -export default abstract class GGUFBaseStream { - protected _buffer = Buffer.alloc(0); - protected constructor() { - } - - public abstract readNBytes(numBytes: number, offset?: number): Promise; - - public async readUint8(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readUint8, offset); - return response.readUInt8(); - } - - public async readUint16(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readUint16, offset); - return response.readUInt16LE(); - } - - public async readUint32(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readUint32, offset); - return response.readUInt32LE(); - } - - public async readUint64(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readUint64, offset); - return response.readBigUInt64LE(); - } - - public async readInt8(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readInt8, offset); - return response.readInt8(); - } - - public async readInt16(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readInt16, offset); - return response.readInt16LE(); - } - - public async readInt32(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readInt32, offset); - return response.readInt32LE(); - } - - public async readInt64(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readInt64, offset); - return response.readBigInt64LE(); - } - - public async readFloat32(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readFloat32, offset); - return response.readFloatLE(); - } - - public async readFloat64(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readFloat64, offset); - return response.readDoubleLE(); - } - - public async readBool(offset: number) { - const response = await this.readNBytes(METHOD_TO_BYTE_COUNT.readUint8, offset); - return response.readUInt8() === 1; - } - - public async readString(offset: number) { - const length = Number(await this.readUint64(offset)); - offset += METHOD_TO_BYTE_COUNT.readUint64; - - const readLength = METHOD_TO_BYTE_COUNT.readUint8 * length; - const stringBytes = await this.readNBytes(readLength, offset); - - return { - string: String.fromCharCode(...stringBytes), - newOffset: offset + readLength - }; - } - - protected _addToBuffer(buffer: Buffer){ - this._buffer = Buffer.concat([this._buffer, buffer]); - } - - public static castNumber(value: bigint) { - if (value > Number.MAX_SAFE_INTEGER) return value; - return Number(value); - } -} diff --git a/src/gguf/ggufParser/stream/GGUFFetchStream.ts b/src/gguf/ggufParser/stream/GGUFFetchStream.ts deleted file mode 100644 index 9cb8f5c8..00000000 --- a/src/gguf/ggufParser/stream/GGUFFetchStream.ts +++ /dev/null @@ -1,47 +0,0 @@ -import retry from "async-retry"; -import {withLock} from "lifecycle-utils"; - -import GgufBaseStream, {ALLOCATION_SIZE} from "./GGUFBaseStream.js"; - -type GGUFFetchStreamOptions = { - retry: retry.Options -}; - -export default class GGUFFetchStream extends GgufBaseStream { - public readonly url: string; - public readonly options: Partial = {}; - - public constructor(url: string, options: Partial = {}) { - super(); - this.options = options; - this.url = url; - } - - public override async readNBytes(numBytes: number, offset = 0): Promise { - return await withLock(this, "_lock", async function readNBytesWithoutLock(): Promise { - if (offset + numBytes < this._buffer.length) { - return this._buffer.subarray(offset, offset + numBytes); - } - - const fetchMissingBytes = await retry(async () => { - return await this._fetchBytesWithoutRetry(this._buffer.length, offset + numBytes + ALLOCATION_SIZE); - }, this.options.retry); - - this._addToBuffer(fetchMissingBytes); - return await readNBytesWithoutLock.call(this); - }); - } - - private async _fetchBytesWithoutRetry(start: number, end: number) { - const response = await fetch(this.url, { - headers: { - Range: `bytes=${start}-${end}`, - accept: "*/*" - } - }); - const arrayBuffer = await response.arrayBuffer(); - return Buffer.from(arrayBuffer); - } -} - - diff --git a/src/gguf/ggufParser/stream/GGUFReadStream.ts b/src/gguf/ggufParser/stream/GGUFReadStream.ts deleted file mode 100644 index 008ceca5..00000000 --- a/src/gguf/ggufParser/stream/GGUFReadStream.ts +++ /dev/null @@ -1,50 +0,0 @@ -import fs from "node:fs/promises"; -import retry from "async-retry"; -import {withLock} from "lifecycle-utils"; -import GgufBaseStream, {ALLOCATION_SIZE} from "./GGUFBaseStream.js"; - -type GGUFReadStreamOptions = { - retry?: retry.Options, - mode: string -}; - -const DEFAULT_OPTIONS: GGUFReadStreamOptions = { - mode: "r" -}; - -export default class GGUFReadStream extends GgufBaseStream { - public readonly options: GGUFReadStreamOptions; - public readonly path: string; - - public constructor(path: string, options: Partial = {}) { - super(); - this.path = path; - this.options = {...DEFAULT_OPTIONS, ...options}; - } - - public override async readNBytes(numBytes: number, offset = 0): Promise { - return await withLock(this, "_lock", async function readNBytesWithoutLock(): Promise { - if (offset + numBytes < this._buffer.length) { - return this._buffer.subarray(offset, offset + numBytes); - } - - const readMissingBytes = await retry(async () => { - return await this._readBytesWithoutRetry(numBytes + ALLOCATION_SIZE, this._buffer.length); - }, this.options.retry); - - this._addToBuffer(readMissingBytes); - return await readNBytesWithoutLock.call(this); - }); - } - - private async _readBytesWithoutRetry(numBytes: number, offset: number) { - const fd = await fs.open(this.path, this.options.mode); - try { - const buffer = Buffer.alloc(numBytes); - await fd.read(buffer, 0, numBytes, offset); - return buffer; - } finally { - await fd.close(); - } - } -} diff --git a/src/gguf/ggufParser/stream/GgufBaseStream.ts b/src/gguf/ggufParser/stream/GgufBaseStream.ts new file mode 100644 index 00000000..d9702240 --- /dev/null +++ b/src/gguf/ggufParser/stream/GgufBaseStream.ts @@ -0,0 +1,106 @@ +import {GgufReadOffset} from "../utils/GgufReadOffset.js"; + +export const METHOD_TO_BYTE_COUNT = { + readUint8: 1, + readUint16: 2, + readUint32: 4, + readUint64: 8, + readInt8: 1, + readInt16: 2, + readInt32: 4, + readInt64: 8, + readFloat32: 4, + readFloat64: 8, + readBool: 1 +} as const; + +export const ALLOCATION_SIZE = 1024 * 1024 * 1.5; // 1.5MB + +export abstract class GgufBaseStream { + protected _buffer = Buffer.alloc(0); + + public abstract readByteRange(offset: number | GgufReadOffset, length: number): Promise; + + public async readUint8(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint8); + return response.readUInt8(); + } + + public async readUint16(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint16); + return response.readUInt16LE(); + } + + public async readUint32(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint32); + return response.readUInt32LE(); + } + + public async readUint64(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint64); + return response.readBigUInt64LE(); + } + + public async readInt8(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readInt8); + return response.readInt8(); + } + + public async readInt16(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readInt16); + return response.readInt16LE(); + } + + public async readInt32(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readInt32); + return response.readInt32LE(); + } + + public async readInt64(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readInt64); + return response.readBigInt64LE(); + } + + public async readFloat32(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readFloat32); + return response.readFloatLE(); + } + + public async readFloat64(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readFloat64); + return response.readDoubleLE(); + } + + public async readBool(offset: number | GgufReadOffset) { + const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint8); + return response.readUInt8() === 1; + } + + public async readString(offset: number | GgufReadOffset) { + const readOffset = GgufReadOffset.resolveReadOffset(offset); + const length = Number(await this.readUint64(readOffset)); + + const readLength = METHOD_TO_BYTE_COUNT.readUint8 * length; + const stringBytes = await this.readByteRange(readOffset, readLength); + + return String.fromCharCode(...stringBytes); + } + + protected _addToBuffer(buffer: Buffer){ + this._buffer = Buffer.concat([this._buffer, buffer]); + } + + private async _readByteRangeAndUpdateOffset(offset: number | GgufReadOffset, length: number) { + const readOffset = GgufReadOffset.resolveReadOffset(offset); + + const response = await this.readByteRange(readOffset.offset, length); + readOffset.moveBy(length); + + return response; + } + + public static castNumber(value: bigint) { + if (value > Number.MAX_SAFE_INTEGER) return value; + return Number(value); + } +} diff --git a/src/gguf/ggufParser/stream/GgufFetchStream.ts b/src/gguf/ggufParser/stream/GgufFetchStream.ts new file mode 100644 index 00000000..f4bd1fd1 --- /dev/null +++ b/src/gguf/ggufParser/stream/GgufFetchStream.ts @@ -0,0 +1,62 @@ +import retry from "async-retry"; +import {withLock} from "lifecycle-utils"; +import {GgufReadOffset} from "../utils/GgufReadOffset.js"; +import {GgufBaseStream, ALLOCATION_SIZE} from "./GgufBaseStream.js"; + +type GgufFetchStreamOptions = { + url: string, + retryOptions?: retry.Options +}; + +const defaultRetryOptions: retry.Options = { + retries: 10, + factor: 2, + minTimeout: 1000, + maxTimeout: 1000 * 16 +} as const; + +export class GgufFetchStream extends GgufBaseStream { + public readonly url: string; + public readonly retryOptions: retry.Options; + + public constructor({url, retryOptions = defaultRetryOptions}: GgufFetchStreamOptions) { + super(); + this.url = url; + this.retryOptions = retryOptions; + } + + public async readByteRange(offset: number | GgufReadOffset, length: number) { + const readOffset = GgufReadOffset.resolveReadOffset(offset); + const endOffset = readOffset.offset + length; + + if (endOffset >= this._buffer.length) + await this._fetchToExpandBufferUpToOffset(endOffset); + + const res = this._buffer.subarray(readOffset.offset, endOffset); + readOffset.moveBy(length); + return res; + } + + private async _fetchToExpandBufferUpToOffset(endOffset: number, extraAllocationSize: number = ALLOCATION_SIZE) { + await withLock(this, "modifyBuffer", async () => { + if (endOffset < this._buffer.length) + return; + + const missingBytesBuffer = await retry(async () => { + return await this._fetchByteRange(this._buffer.length, endOffset + extraAllocationSize - this._buffer.length); + }, this.retryOptions); + this._addToBuffer(missingBytesBuffer); + }); + } + + private async _fetchByteRange(start: number, length: number) { + const response = await fetch(this.url, { + headers: { + Range: `bytes=${start}-${start + length}`, + accept: "*/*" + } + }); + const arrayBuffer = await response.arrayBuffer(); + return Buffer.from(arrayBuffer); + } +} diff --git a/src/gguf/ggufParser/stream/GgufFsReadStream.ts b/src/gguf/ggufParser/stream/GgufFsReadStream.ts new file mode 100644 index 00000000..83e074b4 --- /dev/null +++ b/src/gguf/ggufParser/stream/GgufFsReadStream.ts @@ -0,0 +1,64 @@ +import fs from "node:fs/promises"; +import retry from "async-retry"; +import {withLock} from "lifecycle-utils"; +import {GgufReadOffset} from "../utils/GgufReadOffset.js"; +import {GgufBaseStream, ALLOCATION_SIZE} from "./GgufBaseStream.js"; + +type GgufReadStreamOptions = { + filePath: string, + retryOptions?: retry.Options +}; + +const defaultRetryOptions: retry.Options = { + retries: 10, + factor: 2, + minTimeout: 1000, + maxTimeout: 1000 * 16 +} as const; + +export class GgufFsReadStream extends GgufBaseStream { + public readonly filePath: string; + public readonly retryOptions: retry.Options; + + public constructor({filePath, retryOptions = defaultRetryOptions}: GgufReadStreamOptions) { + super(); + this.filePath = filePath; + this.retryOptions = retryOptions; + } + + public async readByteRange(offset: number | GgufReadOffset, length: number) { + const readOffset = GgufReadOffset.resolveReadOffset(offset); + const endOffset = readOffset.offset + length; + + if (endOffset >= this._buffer.length) + await this._readToExpandBufferUpToOffset(endOffset); + + const res = this._buffer.subarray(readOffset.offset, endOffset); + readOffset.moveBy(length); + return res; + } + + private async _readToExpandBufferUpToOffset(endOffset: number, extraAllocationSize: number = ALLOCATION_SIZE) { + return await withLock(this, "modifyBuffer", async () => { + if (endOffset < this._buffer.length) + return; + + const missingBytesBuffer = await retry(async () => { + return await this._readByteRange(this._buffer.length, endOffset + extraAllocationSize - this._buffer.length); + }, this.retryOptions); + + this._addToBuffer(missingBytesBuffer); + }); + } + + private async _readByteRange(start: number, length: number) { + const fd = await fs.open(this.filePath, "r"); + try { + const buffer = Buffer.alloc(length); + await fd.read(buffer, 0, length, start); + return buffer; + } finally { + await fd.close(); + } + } +} diff --git a/src/gguf/ggufParser/utils/GgufReadOffset.ts b/src/gguf/ggufParser/utils/GgufReadOffset.ts new file mode 100644 index 00000000..4d158ca8 --- /dev/null +++ b/src/gguf/ggufParser/utils/GgufReadOffset.ts @@ -0,0 +1,21 @@ +export class GgufReadOffset { + public offset: number; + + public constructor(offset: number | GgufReadOffset) { + if (offset instanceof GgufReadOffset) + this.offset = offset.offset; + else + this.offset = offset; + } + + public moveBy(amount: number) { + this.offset += amount; + } + + public static resolveReadOffset(offset: number | GgufReadOffset) { + if (offset instanceof GgufReadOffset) + return offset; + + return new GgufReadOffset(offset); + } +} diff --git a/src/gguf/ggufParser/utils/parseGgufFileTypeNumber.ts b/src/gguf/ggufParser/utils/parseGgufFileTypeNumber.ts new file mode 100644 index 00000000..1baa7ef6 --- /dev/null +++ b/src/gguf/ggufParser/utils/parseGgufFileTypeNumber.ts @@ -0,0 +1,34 @@ +import {GgufFileType} from "../GgufMetadataTypes.js"; + +/** + * https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#general-metadata + * Convert file type from string to int + */ +export function parseGgufFileTypeNumber(fileType?: number) { + if (fileType == null) + return undefined; + + switch (fileType) { + case 0: return GgufFileType.ALL_F32; + case 1: return GgufFileType.MOSTLY_F16; + case 2: return GgufFileType.MOSTLY_Q4_0; + case 3: return GgufFileType.MOSTLY_Q4_1; + case 4: return GgufFileType.MOSTLY_Q4_1_SOME_F16; + case 5: return GgufFileType.MOSTLY_Q4_2; + case 6: return GgufFileType.MOSTLY_Q4_3; + case 7: return GgufFileType.MOSTLY_Q8_0; + case 8: return GgufFileType.MOSTLY_Q5_0; + case 9: return GgufFileType.MOSTLY_Q5_1; + case 10: return GgufFileType.MOSTLY_Q2_K; + case 11: return GgufFileType.MOSTLY_Q3_K_S; + case 12: return GgufFileType.MOSTLY_Q3_K_M; + case 13: return GgufFileType.MOSTLY_Q3_K_L; + case 14: return GgufFileType.MOSTLY_Q4_K_S; + case 15: return GgufFileType.MOSTLY_Q4_K_M; + case 16: return GgufFileType.MOSTLY_Q5_K_S; + case 17: return GgufFileType.MOSTLY_Q5_K_M; + case 18: return GgufFileType.MOSTLY_Q6_K; + } + + return undefined; +} diff --git a/test/modelDependent/functionary/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts index f519bcc3..f840641b 100644 --- a/test/modelDependent/functionary/gguf.test.ts +++ b/test/modelDependent/functionary/gguf.test.ts @@ -1,6 +1,6 @@ import {describe, expect, it, test} from "vitest"; -import GGUFReadStream from "../../../src/gguf/ggufParser/stream/GGUFReadStream.js"; -import GGUFParser from "../../../src/gguf/ggufParser/GGUFParser.js"; +import {GgufFsReadStream} from "../../../src/gguf/ggufParser/stream/GgufFsReadStream.js"; +import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; import {getModelFile} from "../../utils/modelFiles.js"; import GGUFInsights from "../../../src/gguf/GGUFInsights.js"; import {getTestLlama} from "../../utils/getTestLlama.js"; @@ -10,26 +10,26 @@ describe("GGUF Parser", async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); test("Magic should be GGUF local model", async () => { - const stream = new GGUFReadStream(modelPath); - const magic = await stream.readNBytes(4); + const stream = new GgufFsReadStream({filePath: modelPath}); + const magic = await stream.readByteRange(0, 4); const magicText = String.fromCharCode(...magic); expect(magicText).toBe("GGUF"); }); it("should parse local gguf model", async () => { - const stream = new GGUFReadStream(modelPath); + const stream = new GgufFsReadStream({filePath: modelPath}); + const ggufParser = new GgufParser({stream}); - const ggufParser = new GGUFParser(stream); const metadata = await ggufParser.parseMetadata(); expect(metadata).toMatchSnapshot(); }); it("should calculate GGUF VRAM Usage", async () => { - const stream = new GGUFReadStream(modelPath); + const stream = new GgufFsReadStream({filePath: modelPath}); + const ggufParser = new GgufParser({stream}); - const ggufParser = new GGUFParser(stream); const metadata = await ggufParser.parseMetadata(); const ggufInsights = new GGUFInsights(metadata); diff --git a/test/standalone/gguf/gguf.test.ts b/test/standalone/gguf/gguf.test.ts index 81070908..8a86e40e 100644 --- a/test/standalone/gguf/gguf.test.ts +++ b/test/standalone/gguf/gguf.test.ts @@ -1,23 +1,23 @@ import {describe, expect, it, test} from "vitest"; -import GGUFParser from "../../../src/gguf/ggufParser/GGUFParser.js"; -import GGUFFetchStream from "../../../src/gguf/ggufParser/stream/GGUFFetchStream.js"; +import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; +import {GgufFetchStream} from "../../../src/gguf/ggufParser/stream/GgufFetchStream.js"; const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; describe("GGUF Parser", async () => { test("Magic should be GGUF remote model", {timeout: 1000 * 60 * 10}, async () => { - const stream = new GGUFFetchStream(remoteGGUFModel); + const stream = new GgufFetchStream({url: remoteGGUFModel}); - const magic = await stream.readNBytes(4); + const magic = await stream.readByteRange(0, 4); const magicText = String.fromCharCode(...magic); expect(magicText).toBe("GGUF"); }); it("should parse remote gguf model", async () => { - const stream = new GGUFFetchStream(remoteGGUFModel); + const stream = new GgufFetchStream({url: remoteGGUFModel}); + const ggufParser = new GgufParser({stream}); - const ggufParser = new GGUFParser(stream); const metadata = await ggufParser.parseMetadata(); expect(metadata).toMatchSnapshot(); From 34a4c501d18b02a6490d88a355bbfd7528d4e66a Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 19 Mar 2024 01:07:39 +0200 Subject: [PATCH 09/52] refactor: rename `stream` to `fileReader` --- src/gguf/GGUFMetadata.ts | 14 +++--- src/gguf/ggufParser/GgufParser.ts | 46 +++++++++---------- .../GgufBaseFileReader.ts} | 2 +- .../GgufFetchFileReader.ts} | 8 ++-- .../GgufFsFileReader.ts} | 8 ++-- test/modelDependent/functionary/gguf.test.ts | 14 +++--- test/standalone/gguf/gguf.test.ts | 10 ++-- 7 files changed, 51 insertions(+), 51 deletions(-) rename src/gguf/ggufParser/{stream/GgufBaseStream.ts => fileReaders/GgufBaseFileReader.ts} (98%) rename src/gguf/ggufParser/{stream/GgufFetchStream.ts => fileReaders/GgufFetchFileReader.ts} (90%) rename src/gguf/ggufParser/{stream/GgufFsReadStream.ts => fileReaders/GgufFsFileReader.ts} (90%) diff --git a/src/gguf/GGUFMetadata.ts b/src/gguf/GGUFMetadata.ts index c8d5796d..5dcb13f9 100644 --- a/src/gguf/GGUFMetadata.ts +++ b/src/gguf/GGUFMetadata.ts @@ -2,8 +2,8 @@ import retry from "async-retry"; import MetadataNotParsedYetError from "./errors/MetadataNotParsedYetError.js"; import GGUFInsights, {GGUFInsightsOptions} from "./GGUFInsights.js"; import {GgufParser, GGUFMetadataResponse} from "./ggufParser/GgufParser.js"; -import {GgufFetchStream} from "./ggufParser/stream/GgufFetchStream.js"; -import {GgufFsReadStream} from "./ggufParser/stream/GgufFsReadStream.js"; +import {GgufFetchFileReader} from "./ggufParser/fileReaders/GgufFetchFileReader.js"; +import {GgufFsFileReader} from "./ggufParser/fileReaders/GgufFsFileReader.js"; export type GGUFMetadataOptions = { source?: "network" | "local", @@ -34,24 +34,24 @@ export default class GGUFMetadata { } public async parse() { - const stream = this._createStream(); + const fileReader = this._createFileReader(); const parser = new GgufParser({ - stream, + fileReader, ignoreKeys: this.options.ignoreKeys }); return this._metadata = await parser.parseMetadata(); } - private _createStream() { + private _createFileReader() { switch (this.options.source) { case "network": - return new GgufFetchStream({ + return new GgufFetchFileReader({ url: this.path, retryOptions: this.options.retry }); case "local": default: - return new GgufFsReadStream({ + return new GgufFsFileReader({ filePath: this.path, retryOptions: this.options.retry }); diff --git a/src/gguf/ggufParser/GgufParser.ts b/src/gguf/ggufParser/GgufParser.ts index 8bac3baf..1eb89d46 100644 --- a/src/gguf/ggufParser/GgufParser.ts +++ b/src/gguf/ggufParser/GgufParser.ts @@ -4,7 +4,7 @@ import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; import {GgufReadOffset} from "./utils/GgufReadOffset.js"; import {parseGgufFileTypeNumber} from "./utils/parseGgufFileTypeNumber.js"; import {GgufMetadataAny} from "./GgufMetadataTypes.js"; -import {GgufBaseStream, METHOD_TO_BYTE_COUNT} from "./stream/GgufBaseStream.js"; +import {GgufBaseFileReader, METHOD_TO_BYTE_COUNT} from "./fileReaders/GgufBaseFileReader.js"; const enum MetadataValueType { Uint8 = 0, @@ -34,17 +34,17 @@ export type GGUFMetadataResponse = { }; export type GgufParserOptions = { - stream: GgufBaseStream, + fileReader: GgufBaseFileReader, ignoreKeys?: string[] }; export class GgufParser { - private readonly _stream: GgufBaseStream; + private readonly _fileReader: GgufBaseFileReader; public ignoreKeys = defaultIgnoreMetadataKeys; - public constructor({stream, ignoreKeys = defaultIgnoreMetadataKeys}: GgufParserOptions) { + public constructor({fileReader, ignoreKeys = defaultIgnoreMetadataKeys}: GgufParserOptions) { this.ignoreKeys = ignoreKeys; - this._stream = stream; + this._fileReader = fileReader; } public async parseMetadata({logWarnings = true}: {logWarnings?: boolean} = {}): Promise { @@ -76,20 +76,20 @@ export class GgufParser { const readOffset = GgufReadOffset.resolveReadOffset(offset); switch (type) { - case MetadataValueType.Uint8: return await this._stream.readUint8(readOffset); - case MetadataValueType.Int8: return await this._stream.readInt8(readOffset); - case MetadataValueType.Uint16: return await this._stream.readUint16(readOffset); - case MetadataValueType.Int16: return await this._stream.readInt16(readOffset); - case MetadataValueType.Uint32: return await this._stream.readUint32(readOffset); - case MetadataValueType.Int32: return await this._stream.readInt32(readOffset); - case MetadataValueType.Float32: return await this._stream.readFloat32(readOffset); - case MetadataValueType.Bool: return await this._stream.readBool(readOffset); - case MetadataValueType.String: return await this._stream.readString(readOffset); + case MetadataValueType.Uint8: return await this._fileReader.readUint8(readOffset); + case MetadataValueType.Int8: return await this._fileReader.readInt8(readOffset); + case MetadataValueType.Uint16: return await this._fileReader.readUint16(readOffset); + case MetadataValueType.Int16: return await this._fileReader.readInt16(readOffset); + case MetadataValueType.Uint32: return await this._fileReader.readUint32(readOffset); + case MetadataValueType.Int32: return await this._fileReader.readInt32(readOffset); + case MetadataValueType.Float32: return await this._fileReader.readFloat32(readOffset); + case MetadataValueType.Bool: return await this._fileReader.readBool(readOffset); + case MetadataValueType.String: return await this._fileReader.readString(readOffset); } if (type === MetadataValueType.Array) { - const arrayType = await this._stream.readUint32(readOffset); - const arrayLength = await this._stream.readUint64(readOffset); + const arrayType = await this._fileReader.readUint32(readOffset); + const arrayLength = await this._fileReader.readUint64(readOffset); const arrayValues: any[] = []; for (let i = 0; i < arrayLength; i++) { @@ -105,24 +105,24 @@ export class GgufParser { private async _parseMetadataRaw(): Promise<{metadata: Record, metadataSize: number}> { const readOffset = new GgufReadOffset(0); - const magicBytes = await this._stream.readByteRange(readOffset, METHOD_TO_BYTE_COUNT.readUint8 * ggufMagic.length); + const magicBytes = await this._fileReader.readByteRange(readOffset, METHOD_TO_BYTE_COUNT.readUint8 * ggufMagic.length); const magicText = String.fromCharCode(...magicBytes); if (magicText !== ggufMagic) throw new InvalidGgufMagicError(); - const version = await this._stream.readUint32(readOffset); - const tensorCount = await this._stream.readUint64(readOffset); - const metadataKVCount = Number(await this._stream.readUint64(readOffset)); + const version = await this._fileReader.readUint32(readOffset); + const tensorCount = await this._fileReader.readUint64(readOffset); + const metadataKVCount = Number(await this._fileReader.readUint64(readOffset)); const metadata: { [key: string]: any } = { version, - tensorCount: GgufBaseStream.castNumber(tensorCount) + tensorCount: GgufBaseFileReader.castNumber(tensorCount) }; for (let i = 0; i < metadataKVCount; i++) { - const keyResult = await this._stream.readString(readOffset); - const valueType = await this._stream.readUint32(readOffset); + const keyResult = await this._fileReader.readString(readOffset); + const valueType = await this._fileReader.readUint32(readOffset); metadata[keyResult] = await this._readMetadataValue(valueType, readOffset); } diff --git a/src/gguf/ggufParser/stream/GgufBaseStream.ts b/src/gguf/ggufParser/fileReaders/GgufBaseFileReader.ts similarity index 98% rename from src/gguf/ggufParser/stream/GgufBaseStream.ts rename to src/gguf/ggufParser/fileReaders/GgufBaseFileReader.ts index d9702240..9c11cf8b 100644 --- a/src/gguf/ggufParser/stream/GgufBaseStream.ts +++ b/src/gguf/ggufParser/fileReaders/GgufBaseFileReader.ts @@ -16,7 +16,7 @@ export const METHOD_TO_BYTE_COUNT = { export const ALLOCATION_SIZE = 1024 * 1024 * 1.5; // 1.5MB -export abstract class GgufBaseStream { +export abstract class GgufBaseFileReader { protected _buffer = Buffer.alloc(0); public abstract readByteRange(offset: number | GgufReadOffset, length: number): Promise; diff --git a/src/gguf/ggufParser/stream/GgufFetchStream.ts b/src/gguf/ggufParser/fileReaders/GgufFetchFileReader.ts similarity index 90% rename from src/gguf/ggufParser/stream/GgufFetchStream.ts rename to src/gguf/ggufParser/fileReaders/GgufFetchFileReader.ts index f4bd1fd1..b64a091c 100644 --- a/src/gguf/ggufParser/stream/GgufFetchStream.ts +++ b/src/gguf/ggufParser/fileReaders/GgufFetchFileReader.ts @@ -1,9 +1,9 @@ import retry from "async-retry"; import {withLock} from "lifecycle-utils"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {GgufBaseStream, ALLOCATION_SIZE} from "./GgufBaseStream.js"; +import {GgufBaseFileReader, ALLOCATION_SIZE} from "./GgufBaseFileReader.js"; -type GgufFetchStreamOptions = { +type GgufFetchFileReaderOptions = { url: string, retryOptions?: retry.Options }; @@ -15,11 +15,11 @@ const defaultRetryOptions: retry.Options = { maxTimeout: 1000 * 16 } as const; -export class GgufFetchStream extends GgufBaseStream { +export class GgufFetchFileReader extends GgufBaseFileReader { public readonly url: string; public readonly retryOptions: retry.Options; - public constructor({url, retryOptions = defaultRetryOptions}: GgufFetchStreamOptions) { + public constructor({url, retryOptions = defaultRetryOptions}: GgufFetchFileReaderOptions) { super(); this.url = url; this.retryOptions = retryOptions; diff --git a/src/gguf/ggufParser/stream/GgufFsReadStream.ts b/src/gguf/ggufParser/fileReaders/GgufFsFileReader.ts similarity index 90% rename from src/gguf/ggufParser/stream/GgufFsReadStream.ts rename to src/gguf/ggufParser/fileReaders/GgufFsFileReader.ts index 83e074b4..23004c4a 100644 --- a/src/gguf/ggufParser/stream/GgufFsReadStream.ts +++ b/src/gguf/ggufParser/fileReaders/GgufFsFileReader.ts @@ -2,9 +2,9 @@ import fs from "node:fs/promises"; import retry from "async-retry"; import {withLock} from "lifecycle-utils"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {GgufBaseStream, ALLOCATION_SIZE} from "./GgufBaseStream.js"; +import {GgufBaseFileReader, ALLOCATION_SIZE} from "./GgufBaseFileReader.js"; -type GgufReadStreamOptions = { +type GgufFsFileReaderOptions = { filePath: string, retryOptions?: retry.Options }; @@ -16,11 +16,11 @@ const defaultRetryOptions: retry.Options = { maxTimeout: 1000 * 16 } as const; -export class GgufFsReadStream extends GgufBaseStream { +export class GgufFsFileReader extends GgufBaseFileReader { public readonly filePath: string; public readonly retryOptions: retry.Options; - public constructor({filePath, retryOptions = defaultRetryOptions}: GgufReadStreamOptions) { + public constructor({filePath, retryOptions = defaultRetryOptions}: GgufFsFileReaderOptions) { super(); this.filePath = filePath; this.retryOptions = retryOptions; diff --git a/test/modelDependent/functionary/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts index f840641b..961ee842 100644 --- a/test/modelDependent/functionary/gguf.test.ts +++ b/test/modelDependent/functionary/gguf.test.ts @@ -1,5 +1,5 @@ import {describe, expect, it, test} from "vitest"; -import {GgufFsReadStream} from "../../../src/gguf/ggufParser/stream/GgufFsReadStream.js"; +import {GgufFsFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufFsFileReader.js"; import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; import {getModelFile} from "../../utils/modelFiles.js"; import GGUFInsights from "../../../src/gguf/GGUFInsights.js"; @@ -10,16 +10,16 @@ describe("GGUF Parser", async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); test("Magic should be GGUF local model", async () => { - const stream = new GgufFsReadStream({filePath: modelPath}); - const magic = await stream.readByteRange(0, 4); + const fileReader = new GgufFsFileReader({filePath: modelPath}); + const magic = await fileReader.readByteRange(0, 4); const magicText = String.fromCharCode(...magic); expect(magicText).toBe("GGUF"); }); it("should parse local gguf model", async () => { - const stream = new GgufFsReadStream({filePath: modelPath}); - const ggufParser = new GgufParser({stream}); + const fileReader = new GgufFsFileReader({filePath: modelPath}); + const ggufParser = new GgufParser({fileReader: fileReader}); const metadata = await ggufParser.parseMetadata(); @@ -27,8 +27,8 @@ describe("GGUF Parser", async () => { }); it("should calculate GGUF VRAM Usage", async () => { - const stream = new GgufFsReadStream({filePath: modelPath}); - const ggufParser = new GgufParser({stream}); + const fileReader = new GgufFsFileReader({filePath: modelPath}); + const ggufParser = new GgufParser({fileReader: fileReader}); const metadata = await ggufParser.parseMetadata(); diff --git a/test/standalone/gguf/gguf.test.ts b/test/standalone/gguf/gguf.test.ts index 8a86e40e..15fc14e0 100644 --- a/test/standalone/gguf/gguf.test.ts +++ b/test/standalone/gguf/gguf.test.ts @@ -1,22 +1,22 @@ import {describe, expect, it, test} from "vitest"; import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; -import {GgufFetchStream} from "../../../src/gguf/ggufParser/stream/GgufFetchStream.js"; +import {GgufFetchFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufFetchFileReader.js"; const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; describe("GGUF Parser", async () => { test("Magic should be GGUF remote model", {timeout: 1000 * 60 * 10}, async () => { - const stream = new GgufFetchStream({url: remoteGGUFModel}); + const fileReader = new GgufFetchFileReader({url: remoteGGUFModel}); - const magic = await stream.readByteRange(0, 4); + const magic = await fileReader.readByteRange(0, 4); const magicText = String.fromCharCode(...magic); expect(magicText).toBe("GGUF"); }); it("should parse remote gguf model", async () => { - const stream = new GgufFetchStream({url: remoteGGUFModel}); - const ggufParser = new GgufParser({stream}); + const fileReader = new GgufFetchFileReader({url: remoteGGUFModel}); + const ggufParser = new GgufParser({fileReader: fileReader}); const metadata = await ggufParser.parseMetadata(); From 69466ae4ceb45748454d7119dfb3f29faa5f43c9 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 19 Mar 2024 01:23:33 +0200 Subject: [PATCH 10/52] refactor: gguf --- src/gguf/GGUFInsights.ts | 83 ------------------- src/gguf/GGUFMetadata.ts | 60 -------------- src/gguf/GgufInsights.ts | 45 ++++++++++ src/gguf/consts.ts | 10 +++ src/gguf/errors/InvalidGgufMagicError.ts | 4 +- src/gguf/errors/MetadataNotParsedYetError.ts | 5 -- src/gguf/errors/MissingNodeLlamaError.ts | 5 -- .../errors/ModelScore/NotEnoughVRamError.ts | 12 --- .../errors/UnsupportedMetadataTypeError.ts | 11 ++- src/gguf/ggufParser/GgufParser.ts | 27 +++--- ...gufBaseFileReader.ts => GgufFileReader.ts} | 52 ++++++------ .../fileReaders/GgufFsFileReader.ts | 16 ++-- ...eader.ts => GgufNetworkFetchFileReader.ts} | 26 +++--- src/gguf/parseGgufMetadata.ts | 41 +++++++++ src/utils/spawnCommand.ts | 6 +- test/modelDependent/functionary/gguf.test.ts | 15 ++-- test/standalone/gguf/gguf.test.ts | 6 +- 17 files changed, 177 insertions(+), 247 deletions(-) delete mode 100644 src/gguf/GGUFInsights.ts delete mode 100644 src/gguf/GGUFMetadata.ts create mode 100644 src/gguf/GgufInsights.ts create mode 100644 src/gguf/consts.ts delete mode 100644 src/gguf/errors/MetadataNotParsedYetError.ts delete mode 100644 src/gguf/errors/MissingNodeLlamaError.ts delete mode 100644 src/gguf/errors/ModelScore/NotEnoughVRamError.ts rename src/gguf/ggufParser/fileReaders/{GgufBaseFileReader.ts => GgufFileReader.ts} (77%) rename src/gguf/ggufParser/fileReaders/{GgufFetchFileReader.ts => GgufNetworkFetchFileReader.ts} (70%) create mode 100644 src/gguf/parseGgufMetadata.ts diff --git a/src/gguf/GGUFInsights.ts b/src/gguf/GGUFInsights.ts deleted file mode 100644 index b8a33467..00000000 --- a/src/gguf/GGUFInsights.ts +++ /dev/null @@ -1,83 +0,0 @@ -import {Llama} from "../bindings/Llama.js"; -import MissingNodeLlamaError from "./errors/MissingNodeLlamaError.js"; -import {GGUFMetadataResponse} from "./ggufParser/GgufParser.js"; -import NotEnoughVRamError from "./errors/ModelScore/NotEnoughVRamError.js"; - -const PAD_AVAILABLE_VRAM = 1024 ** 2 * 500; // 500MB - -export type GGUFInsightsOptions = { - contextCount?: number, - nodeLlama?: Llama, - modelSize?: number -}; - -export default class GGUFInsights { - public readonly metadataResponse: GGUFMetadataResponse; - public readonly options: GGUFInsightsOptions = {}; - - public get metadata() { - return this.metadataResponse.metadata; - } - - public get architectureMetadata() { - return this.metadata[this.metadata.general.architecture]; - } - - /** - * fp16 k,v matrices - */ - public get kvMatrices(){ - // 2 bytes each * 2 key and value - return ( - 2 * 2 * - this.architectureMetadata.context_length * - this.architectureMetadata.block_count * - this.architectureMetadata.embedding_length * - this.architectureMetadata.attention.head_count_kv / - this.architectureMetadata.attention.head_count - ); - } - - /** - * This amount is the overhead + tensors in memory - */ - public get graphSize() { - // TODO: get this from the llama.cpp's graph calculations instead of - // estimating it's 1/6 * kv_cache_size * num_gqa - return ( - (this.architectureMetadata.attention.head_count_kv / - this.architectureMetadata.attention.head_count) * this.kvMatrices / 6 - ); - } - - public get VRAMUsage(){ - return this.graphSize + this.kvMatrices + this.metadataResponse.metadataSize; - } - - protected get _availableVRam(){ - if (!this.options?.nodeLlama){ - throw new MissingNodeLlamaError("GGUFInsights Calculations"); - } - return this.options.nodeLlama.getVramState().total - PAD_AVAILABLE_VRAM; - } - - public constructor(metadataResponse: GGUFMetadataResponse, options: GGUFInsightsOptions = {}) { - this.options = options; - this.metadataResponse = metadataResponse; - - } - - - /** - * The score of the model by how much it's compatible to the current system - */ - public modelScore(){ - const vramScore = this.VRAMUsage / this._availableVRam; - if (vramScore >= 1){ - throw new NotEnoughVRamError(this.VRAMUsage, this._availableVRam); - } - - return vramScore; - } - -} diff --git a/src/gguf/GGUFMetadata.ts b/src/gguf/GGUFMetadata.ts deleted file mode 100644 index 5dcb13f9..00000000 --- a/src/gguf/GGUFMetadata.ts +++ /dev/null @@ -1,60 +0,0 @@ -import retry from "async-retry"; -import MetadataNotParsedYetError from "./errors/MetadataNotParsedYetError.js"; -import GGUFInsights, {GGUFInsightsOptions} from "./GGUFInsights.js"; -import {GgufParser, GGUFMetadataResponse} from "./ggufParser/GgufParser.js"; -import {GgufFetchFileReader} from "./ggufParser/fileReaders/GgufFetchFileReader.js"; -import {GgufFsFileReader} from "./ggufParser/fileReaders/GgufFsFileReader.js"; - -export type GGUFMetadataOptions = { - source?: "network" | "local", - retry?: retry.Options, - ignoreKeys?: string[], - insights?: GGUFInsightsOptions -}; - -export default class GGUFMetadata { - protected _metadata?: GGUFMetadataResponse; - public readonly path: string; - public readonly options: Partial = {}; - - public constructor(path: string, options: Partial = {}) { - this.options = options; - this.path = path; - } - - public get metadata() { - if (!this._metadata) { - throw new MetadataNotParsedYetError(this.path); - } - return this._metadata; - } - - public get insights(){ - return new GGUFInsights(this.metadata, this.options.insights); - } - - public async parse() { - const fileReader = this._createFileReader(); - const parser = new GgufParser({ - fileReader, - ignoreKeys: this.options.ignoreKeys - }); - return this._metadata = await parser.parseMetadata(); - } - - private _createFileReader() { - switch (this.options.source) { - case "network": - return new GgufFetchFileReader({ - url: this.path, - retryOptions: this.options.retry - }); - case "local": - default: - return new GgufFsFileReader({ - filePath: this.path, - retryOptions: this.options.retry - }); - } - } -} diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts new file mode 100644 index 00000000..0ba9e8b4 --- /dev/null +++ b/src/gguf/GgufInsights.ts @@ -0,0 +1,45 @@ +import {GgufParsedMetadataResult} from "./ggufParser/GgufParser.js"; + +export class GgufInsights { + public readonly metadataResponse: GgufParsedMetadataResult; + + public constructor(metadataResponse: GgufParsedMetadataResult) { + this.metadataResponse = metadataResponse; + } + + public get metadata() { + return this.metadataResponse.metadata; + } + + public get architectureMetadata() { + return this.metadata[this.metadata.general.architecture]; + } + + /** + * fp16 k,v matrices + */ + public get kvMatrices() { + // 2 bytes each * 2 key and value + return ( + 2 * 2 * + this.architectureMetadata.context_length * + this.architectureMetadata.block_count * + this.architectureMetadata.embedding_length * + this.architectureMetadata.attention.head_count_kv / + this.architectureMetadata.attention.head_count + ); + } + + /** + * This amount is the overhead + tensors in memory + */ + public get graphSize() { + // TODO: get this from the llama.cpp's graph calculations instead of + // estimating it's 1/6 * kv_cache_size * num_gqa + return (this.architectureMetadata.attention.head_count_kv / this.architectureMetadata.attention.head_count) * this.kvMatrices / 6; + } + + public get VRAMUsage() { + return this.graphSize + this.kvMatrices + this.metadataResponse.metadataSize; + } +} diff --git a/src/gguf/consts.ts b/src/gguf/consts.ts new file mode 100644 index 00000000..9140b0b3 --- /dev/null +++ b/src/gguf/consts.ts @@ -0,0 +1,10 @@ +import retry from "async-retry"; + +export const ggufDefaultRetryOptions: retry.Options = { + retries: 10, + factor: 2, + minTimeout: 1000, + maxTimeout: 1000 * 16 +} as const; + +export const defaultExtraAllocationSize = 1024 * 1024 * 1.5; // 1.5MB diff --git a/src/gguf/errors/InvalidGgufMagicError.ts b/src/gguf/errors/InvalidGgufMagicError.ts index 0efe3ef6..2225da45 100644 --- a/src/gguf/errors/InvalidGgufMagicError.ts +++ b/src/gguf/errors/InvalidGgufMagicError.ts @@ -1,5 +1,5 @@ export class InvalidGgufMagicError extends Error { - public constructor(message = "Invalid GGUF magic") { - super(message); + public constructor(expectedGgufMagic: string, actualGgufMagic: string) { + super(`Invalid GGUF magic. Expected "${expectedGgufMagic}" but got "${actualGgufMagic}".`); } } diff --git a/src/gguf/errors/MetadataNotParsedYetError.ts b/src/gguf/errors/MetadataNotParsedYetError.ts deleted file mode 100644 index c0d17cf2..00000000 --- a/src/gguf/errors/MetadataNotParsedYetError.ts +++ /dev/null @@ -1,5 +0,0 @@ -export default class MetadataNotParsedYetError extends Error { - public constructor(path: string) { - super(`Metadata not parsed yet: "${path}"`); - } -} diff --git a/src/gguf/errors/MissingNodeLlamaError.ts b/src/gguf/errors/MissingNodeLlamaError.ts deleted file mode 100644 index 1e999b07..00000000 --- a/src/gguf/errors/MissingNodeLlamaError.ts +++ /dev/null @@ -1,5 +0,0 @@ -export default class MissingNodeLlamaError extends Error { - public constructor(purpose: string) { - super(`Missing nodeLlama options, this in required for ${purpose}`); - } -} diff --git a/src/gguf/errors/ModelScore/NotEnoughVRamError.ts b/src/gguf/errors/ModelScore/NotEnoughVRamError.ts deleted file mode 100644 index b5fe5f32..00000000 --- a/src/gguf/errors/ModelScore/NotEnoughVRamError.ts +++ /dev/null @@ -1,12 +0,0 @@ -import bytes from "bytes"; - -export default class NotEnoughVRamError extends Error { - public readonly requiredVRAM: number; - public readonly availableVRAM: number; - - public constructor(requiredVRAM: number, availableVRAM: number) { - super(`${bytes(requiredVRAM)} of VRAM is required, but only ${bytes(availableVRAM)} is available`); - this.availableVRAM = availableVRAM; - this.requiredVRAM = requiredVRAM; - } -} diff --git a/src/gguf/errors/UnsupportedMetadataTypeError.ts b/src/gguf/errors/UnsupportedMetadataTypeError.ts index 65344239..3be0a8b9 100644 --- a/src/gguf/errors/UnsupportedMetadataTypeError.ts +++ b/src/gguf/errors/UnsupportedMetadataTypeError.ts @@ -1,8 +1,11 @@ export default class UnsupportedMetadataTypeError extends Error { - public readonly type: number; + public readonly metadataValueType: number; - public constructor(type: number) { - super(`Unsupported metadata type: "${type}"`); - this.type = type; + public constructor(metadataValueType: number) { + super(`Unsupported GGUF metadata value type "${metadataValueType}"`); + + Object.defineProperty(this, "metadataValueType" satisfies keyof this, {enumerable: false}); + + this.metadataValueType = metadataValueType; } } diff --git a/src/gguf/ggufParser/GgufParser.ts b/src/gguf/ggufParser/GgufParser.ts index 1eb89d46..5c32b91b 100644 --- a/src/gguf/ggufParser/GgufParser.ts +++ b/src/gguf/ggufParser/GgufParser.ts @@ -1,10 +1,10 @@ import {InvalidGgufMagicError} from "../errors/InvalidGgufMagicError.js"; -import UnsupportedMetadataTypeError from "../errors/UnsupportedMetadataTypeError.js"; +import UnsupportedGgufMetadataTypeError from "../errors/UnsupportedMetadataTypeError.js"; import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; import {GgufReadOffset} from "./utils/GgufReadOffset.js"; import {parseGgufFileTypeNumber} from "./utils/parseGgufFileTypeNumber.js"; import {GgufMetadataAny} from "./GgufMetadataTypes.js"; -import {GgufBaseFileReader, METHOD_TO_BYTE_COUNT} from "./fileReaders/GgufBaseFileReader.js"; +import {GgufFileReader, valueTypeToBytesToRead} from "./fileReaders/GgufFileReader.js"; const enum MetadataValueType { Uint8 = 0, @@ -21,6 +21,7 @@ const enum MetadataValueType { const ggufMagic = "GGUF"; +// these keys are ignored by default because they contain very long values that aren't very useful in the JS side of this library const defaultIgnoreMetadataKeys = [ "tokenizer.ggml.tokens", "tokenizer.ggml.scores", @@ -28,26 +29,26 @@ const defaultIgnoreMetadataKeys = [ "tokenizer.ggml.merges" ]; -export type GGUFMetadataResponse = { +export type GgufParsedMetadataResult = { metadataSize: number, metadata: GgufMetadataAny }; export type GgufParserOptions = { - fileReader: GgufBaseFileReader, + fileReader: GgufFileReader, ignoreKeys?: string[] }; export class GgufParser { - private readonly _fileReader: GgufBaseFileReader; - public ignoreKeys = defaultIgnoreMetadataKeys; + private readonly _fileReader: GgufFileReader; + public readonly ignoreKeys = defaultIgnoreMetadataKeys; public constructor({fileReader, ignoreKeys = defaultIgnoreMetadataKeys}: GgufParserOptions) { this.ignoreKeys = ignoreKeys; this._fileReader = fileReader; } - public async parseMetadata({logWarnings = true}: {logWarnings?: boolean} = {}): Promise { + public async parseMetadata({logWarnings = true}: {logWarnings?: boolean} = {}): Promise { const metadataRaw = await this._parseMetadataRaw(); const metadata: { [key: string]: any } = {}; @@ -99,17 +100,17 @@ export class GgufParser { return arrayValues; } - throw new UnsupportedMetadataTypeError(type); + throw new UnsupportedGgufMetadataTypeError(type); } private async _parseMetadataRaw(): Promise<{metadata: Record, metadataSize: number}> { const readOffset = new GgufReadOffset(0); - const magicBytes = await this._fileReader.readByteRange(readOffset, METHOD_TO_BYTE_COUNT.readUint8 * ggufMagic.length); - const magicText = String.fromCharCode(...magicBytes); + const fileMagicBytes = await this._fileReader.readByteRange(readOffset, valueTypeToBytesToRead.uint8 * ggufMagic.length); + const fileMagicText = String.fromCharCode(...fileMagicBytes); - if (magicText !== ggufMagic) - throw new InvalidGgufMagicError(); + if (fileMagicText !== ggufMagic) + throw new InvalidGgufMagicError(ggufMagic, fileMagicText); const version = await this._fileReader.readUint32(readOffset); const tensorCount = await this._fileReader.readUint64(readOffset); @@ -117,7 +118,7 @@ export class GgufParser { const metadata: { [key: string]: any } = { version, - tensorCount: GgufBaseFileReader.castNumber(tensorCount) + tensorCount: GgufFileReader.castNumber(tensorCount) }; for (let i = 0; i < metadataKVCount; i++) { diff --git a/src/gguf/ggufParser/fileReaders/GgufBaseFileReader.ts b/src/gguf/ggufParser/fileReaders/GgufFileReader.ts similarity index 77% rename from src/gguf/ggufParser/fileReaders/GgufBaseFileReader.ts rename to src/gguf/ggufParser/fileReaders/GgufFileReader.ts index 9c11cf8b..7f45cf22 100644 --- a/src/gguf/ggufParser/fileReaders/GgufBaseFileReader.ts +++ b/src/gguf/ggufParser/fileReaders/GgufFileReader.ts @@ -1,78 +1,76 @@ import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -export const METHOD_TO_BYTE_COUNT = { - readUint8: 1, - readUint16: 2, - readUint32: 4, - readUint64: 8, - readInt8: 1, - readInt16: 2, - readInt32: 4, - readInt64: 8, - readFloat32: 4, - readFloat64: 8, - readBool: 1 +export const valueTypeToBytesToRead = { + uint8: 1, + uint16: 2, + uint32: 4, + uint64: 8, + int8: 1, + int16: 2, + int32: 4, + int64: 8, + float32: 4, + float64: 8, + bool: 1 } as const; -export const ALLOCATION_SIZE = 1024 * 1024 * 1.5; // 1.5MB - -export abstract class GgufBaseFileReader { +export abstract class GgufFileReader { protected _buffer = Buffer.alloc(0); public abstract readByteRange(offset: number | GgufReadOffset, length: number): Promise; public async readUint8(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint8); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.uint8); return response.readUInt8(); } public async readUint16(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint16); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.uint16); return response.readUInt16LE(); } public async readUint32(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint32); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.uint32); return response.readUInt32LE(); } public async readUint64(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint64); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.uint64); return response.readBigUInt64LE(); } public async readInt8(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readInt8); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.int8); return response.readInt8(); } public async readInt16(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readInt16); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.int16); return response.readInt16LE(); } public async readInt32(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readInt32); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.int32); return response.readInt32LE(); } public async readInt64(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readInt64); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.int64); return response.readBigInt64LE(); } public async readFloat32(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readFloat32); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.float32); return response.readFloatLE(); } public async readFloat64(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readFloat64); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.float64); return response.readDoubleLE(); } public async readBool(offset: number | GgufReadOffset) { - const response = await this._readByteRangeAndUpdateOffset(offset, METHOD_TO_BYTE_COUNT.readUint8); + const response = await this._readByteRangeAndUpdateOffset(offset, valueTypeToBytesToRead.uint8); return response.readUInt8() === 1; } @@ -80,7 +78,7 @@ export abstract class GgufBaseFileReader { const readOffset = GgufReadOffset.resolveReadOffset(offset); const length = Number(await this.readUint64(readOffset)); - const readLength = METHOD_TO_BYTE_COUNT.readUint8 * length; + const readLength = valueTypeToBytesToRead.uint8 * length; const stringBytes = await this.readByteRange(readOffset, readLength); return String.fromCharCode(...stringBytes); diff --git a/src/gguf/ggufParser/fileReaders/GgufFsFileReader.ts b/src/gguf/ggufParser/fileReaders/GgufFsFileReader.ts index 23004c4a..cf253271 100644 --- a/src/gguf/ggufParser/fileReaders/GgufFsFileReader.ts +++ b/src/gguf/ggufParser/fileReaders/GgufFsFileReader.ts @@ -2,25 +2,19 @@ import fs from "node:fs/promises"; import retry from "async-retry"; import {withLock} from "lifecycle-utils"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {GgufBaseFileReader, ALLOCATION_SIZE} from "./GgufBaseFileReader.js"; +import {defaultExtraAllocationSize, ggufDefaultRetryOptions} from "../../consts.js"; +import {GgufFileReader} from "./GgufFileReader.js"; type GgufFsFileReaderOptions = { filePath: string, retryOptions?: retry.Options }; -const defaultRetryOptions: retry.Options = { - retries: 10, - factor: 2, - minTimeout: 1000, - maxTimeout: 1000 * 16 -} as const; - -export class GgufFsFileReader extends GgufBaseFileReader { +export class GgufFsFileReader extends GgufFileReader { public readonly filePath: string; public readonly retryOptions: retry.Options; - public constructor({filePath, retryOptions = defaultRetryOptions}: GgufFsFileReaderOptions) { + public constructor({filePath, retryOptions = ggufDefaultRetryOptions}: GgufFsFileReaderOptions) { super(); this.filePath = filePath; this.retryOptions = retryOptions; @@ -38,7 +32,7 @@ export class GgufFsFileReader extends GgufBaseFileReader { return res; } - private async _readToExpandBufferUpToOffset(endOffset: number, extraAllocationSize: number = ALLOCATION_SIZE) { + private async _readToExpandBufferUpToOffset(endOffset: number, extraAllocationSize: number = defaultExtraAllocationSize) { return await withLock(this, "modifyBuffer", async () => { if (endOffset < this._buffer.length) return; diff --git a/src/gguf/ggufParser/fileReaders/GgufFetchFileReader.ts b/src/gguf/ggufParser/fileReaders/GgufNetworkFetchFileReader.ts similarity index 70% rename from src/gguf/ggufParser/fileReaders/GgufFetchFileReader.ts rename to src/gguf/ggufParser/fileReaders/GgufNetworkFetchFileReader.ts index b64a091c..adcc370f 100644 --- a/src/gguf/ggufParser/fileReaders/GgufFetchFileReader.ts +++ b/src/gguf/ggufParser/fileReaders/GgufNetworkFetchFileReader.ts @@ -1,28 +1,25 @@ import retry from "async-retry"; import {withLock} from "lifecycle-utils"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {GgufBaseFileReader, ALLOCATION_SIZE} from "./GgufBaseFileReader.js"; +import {defaultExtraAllocationSize, ggufDefaultRetryOptions} from "../../consts.js"; +import {GgufFileReader} from "./GgufFileReader.js"; type GgufFetchFileReaderOptions = { url: string, - retryOptions?: retry.Options + retryOptions?: retry.Options, + headers?: Record }; -const defaultRetryOptions: retry.Options = { - retries: 10, - factor: 2, - minTimeout: 1000, - maxTimeout: 1000 * 16 -} as const; - -export class GgufFetchFileReader extends GgufBaseFileReader { +export class GgufNetworkFetchFileReader extends GgufFileReader { public readonly url: string; public readonly retryOptions: retry.Options; + public readonly headers: Record; - public constructor({url, retryOptions = defaultRetryOptions}: GgufFetchFileReaderOptions) { + public constructor({url, retryOptions = ggufDefaultRetryOptions, headers}: GgufFetchFileReaderOptions) { super(); this.url = url; this.retryOptions = retryOptions; + this.headers = headers ?? {}; } public async readByteRange(offset: number | GgufReadOffset, length: number) { @@ -37,7 +34,7 @@ export class GgufFetchFileReader extends GgufBaseFileReader { return res; } - private async _fetchToExpandBufferUpToOffset(endOffset: number, extraAllocationSize: number = ALLOCATION_SIZE) { + private async _fetchToExpandBufferUpToOffset(endOffset: number, extraAllocationSize: number = defaultExtraAllocationSize) { await withLock(this, "modifyBuffer", async () => { if (endOffset < this._buffer.length) return; @@ -52,10 +49,15 @@ export class GgufFetchFileReader extends GgufBaseFileReader { private async _fetchByteRange(start: number, length: number) { const response = await fetch(this.url, { headers: { + ...this.headers, Range: `bytes=${start}-${start + length}`, accept: "*/*" } }); + + if (!response.ok) + throw new Error(`Failed to fetch byte range: ${response.status} ${response.statusText}`); + const arrayBuffer = await response.arrayBuffer(); return Buffer.from(arrayBuffer); } diff --git a/src/gguf/parseGgufMetadata.ts b/src/gguf/parseGgufMetadata.ts new file mode 100644 index 00000000..4f16d5cf --- /dev/null +++ b/src/gguf/parseGgufMetadata.ts @@ -0,0 +1,41 @@ +import retry from "async-retry"; +import {GgufParser} from "./ggufParser/GgufParser.js"; +import {GgufNetworkFetchFileReader} from "./ggufParser/fileReaders/GgufNetworkFetchFileReader.js"; +import {GgufFsFileReader} from "./ggufParser/fileReaders/GgufFsFileReader.js"; +import {ggufDefaultRetryOptions} from "./consts.js"; + + +export async function parseGgufMetadata(pathOrUrl: string, { + sourceType, + retryOptions = ggufDefaultRetryOptions, + ignoreKeys = [] +}: { + sourceType?: "network" | "filesystem", + retryOptions?: retry.Options, + ignoreKeys?: string[] +} = {}) { + function createFileReader() { + if (sourceType === "network" || (sourceType == null && (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")))) { + return new GgufNetworkFetchFileReader({ + url: pathOrUrl, + retryOptions: retryOptions + }); + } else if (sourceType === "filesystem" || sourceType == null) { + return new GgufFsFileReader({ + filePath: pathOrUrl, + retryOptions: retryOptions + }); + } + + void (sourceType satisfies never); + throw new Error(`Unsupported sourceType: ${sourceType}`); + } + + const fileReader = createFileReader(); + const parser = new GgufParser({ + fileReader, + ignoreKeys + }); + + return await parser.parseMetadata(); +} diff --git a/src/utils/spawnCommand.ts b/src/utils/spawnCommand.ts index aa6ff7c2..dacaf7aa 100644 --- a/src/utils/spawnCommand.ts +++ b/src/utils/spawnCommand.ts @@ -93,9 +93,9 @@ export class SpawnError extends Error { public constructor(message: string, stdout: string, stderr: string, combinedStd: string) { super(message); - Object.defineProperty(this, "stdout", {enumerable: false}); - Object.defineProperty(this, "stderr", {enumerable: false}); - Object.defineProperty(this, "combinedStd", {enumerable: false}); + Object.defineProperty(this, "stdout" satisfies keyof this, {enumerable: false}); + Object.defineProperty(this, "stderr" satisfies keyof this, {enumerable: false}); + Object.defineProperty(this, "combinedStd" satisfies keyof this, {enumerable: false}); this.stdout = stdout; this.stderr = stderr; diff --git a/test/modelDependent/functionary/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts index 961ee842..ab886a89 100644 --- a/test/modelDependent/functionary/gguf.test.ts +++ b/test/modelDependent/functionary/gguf.test.ts @@ -2,9 +2,9 @@ import {describe, expect, it, test} from "vitest"; import {GgufFsFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufFsFileReader.js"; import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; import {getModelFile} from "../../utils/modelFiles.js"; -import GGUFInsights from "../../../src/gguf/GGUFInsights.js"; +import {GgufInsights} from "../../../src/gguf/GgufInsights.js"; import {getTestLlama} from "../../utils/getTestLlama.js"; -import GGUFMetadata from "../../../src/gguf/GGUFMetadata.js"; +import {parseGgufMetadata} from "../../../src/gguf/parseGgufMetadata.js"; describe("GGUF Parser", async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); @@ -32,7 +32,7 @@ describe("GGUF Parser", async () => { const metadata = await ggufParser.parseMetadata(); - const ggufInsights = new GGUFInsights(metadata); + const ggufInsights = new GgufInsights(metadata); const llama = await getTestLlama(); const model = await llama.loadModel({ @@ -48,10 +48,11 @@ describe("GGUF Parser", async () => { }); it("should fetch GGUF metadata", async () => { - const ggufMetadata = new GGUFMetadata(modelPath); - await ggufMetadata.parse(); + const ggufMetadataParseResult = await parseGgufMetadata(modelPath); - expect(ggufMetadata.metadata).toMatchSnapshot(); - expect(ggufMetadata.insights.VRAMUsage).toMatchInlineSnapshot("4474643028.666667"); + expect(ggufMetadataParseResult).toMatchSnapshot(); + + const insights = new GgufInsights(ggufMetadataParseResult); + expect(insights.VRAMUsage).toMatchInlineSnapshot("4474643028.666667"); }); }); diff --git a/test/standalone/gguf/gguf.test.ts b/test/standalone/gguf/gguf.test.ts index 15fc14e0..95887331 100644 --- a/test/standalone/gguf/gguf.test.ts +++ b/test/standalone/gguf/gguf.test.ts @@ -1,12 +1,12 @@ import {describe, expect, it, test} from "vitest"; import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; -import {GgufFetchFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufFetchFileReader.js"; +import {GgufNetworkFetchFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufNetworkFetchFileReader.js"; const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; describe("GGUF Parser", async () => { test("Magic should be GGUF remote model", {timeout: 1000 * 60 * 10}, async () => { - const fileReader = new GgufFetchFileReader({url: remoteGGUFModel}); + const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); const magic = await fileReader.readByteRange(0, 4); const magicText = String.fromCharCode(...magic); @@ -15,7 +15,7 @@ describe("GGUF Parser", async () => { }); it("should parse remote gguf model", async () => { - const fileReader = new GgufFetchFileReader({url: remoteGGUFModel}); + const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); const ggufParser = new GgufParser({fileReader: fileReader}); const metadata = await ggufParser.parseMetadata(); From d0bd7cc06a2fb6fe7ea267ef99e3d0b1f324a887 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 19 Mar 2024 21:22:02 +0200 Subject: [PATCH 11/52] feat: `inspect gguf` command --- .vitepress/config.ts | 10 +- .vitepress/utils/getCommandHtmlDoc.ts | 84 ++++++++++-- docs/guide/cli/cli.data.ts | 16 ++- docs/guide/cli/inspect.md | 2 +- docs/guide/cli/inspect/gguf.md | 17 +++ docs/guide/cli/inspect/gpu.md | 17 +++ src/cli/cli.ts | 2 +- src/cli/commands/InspectCommand.ts | 127 ------------------ src/cli/commands/inspect/InspectCommand.ts | 20 +++ .../inspect/commands/InspectGgufCommand.ts | 38 ++++++ .../inspect/commands/InspectGpuCommand.ts | 110 +++++++++++++++ src/utils/prettyPrintObject.ts | 44 ++++-- 12 files changed, 337 insertions(+), 150 deletions(-) create mode 100644 docs/guide/cli/inspect/gguf.md create mode 100644 docs/guide/cli/inspect/gpu.md delete mode 100644 src/cli/commands/InspectCommand.ts create mode 100644 src/cli/commands/inspect/InspectCommand.ts create mode 100644 src/cli/commands/inspect/commands/InspectGgufCommand.ts create mode 100644 src/cli/commands/inspect/commands/InspectGpuCommand.ts diff --git a/.vitepress/config.ts b/.vitepress/config.ts index 247d3c3b..dd48d2dd 100644 --- a/.vitepress/config.ts +++ b/.vitepress/config.ts @@ -199,7 +199,15 @@ export default defineConfig({ {text: "Download", link: "/download"}, {text: "Complete", link: "/complete"}, {text: "Infill", link: "/infill"}, - {text: "Inspect", link: "/inspect"}, + { + text: "Inspect", + link: "/inspect", + collapsed: true, + items: [ + {text: "GPU", link: "/inspect/gpu"}, + {text: "GGUF", link: "/inspect/gguf"}, + ] + }, {text: "Build", link: "/build"}, {text: "Clear", link: "/clear"} ] diff --git a/.vitepress/utils/getCommandHtmlDoc.ts b/.vitepress/utils/getCommandHtmlDoc.ts index 5af39b5c..c2041d99 100644 --- a/.vitepress/utils/getCommandHtmlDoc.ts +++ b/.vitepress/utils/getCommandHtmlDoc.ts @@ -4,13 +4,54 @@ import {cliBinName, npxRunPrefix} from "../../src/config.js"; import {buildHtmlTable} from "./buildHtmlTable.js"; import {buildHtmlHeading} from "./buildHtmlHeading.js"; -export async function getCommandHtmlDoc(command: CommandModule, cliName: string = cliBinName) { - const title = cliName + " " + (command.command ?? ""); +export async function getCommandHtmlDoc(command: CommandModule, { + cliName = cliBinName, + parentCommand, + subCommandsParentPageLink +}: { + cliName?: string, + parentCommand?: CommandModule, + subCommandsParentPageLink?: string +} = {}) { + const currentCommandCliCommand = resolveCommandCliCommand(command); + const resolvedParentCommandCliCommand = resolveCommandCliCommand(parentCommand); + const title = cliName + " " + (resolvedParentCommandCliCommand ?? "").replace("", currentCommandCliCommand ?? ""); const description = command.describe ?? ""; - const optionGroups = await getOptionsGroupFromCommand(command); + const {subCommands, optionGroups} = await parseCommandDefinition(command); let res = ""; + if (subCommands.length > 0) { + res += buildHtmlHeading("h2", htmlEscape("Commands"), "commands"); + + res += buildHtmlTable( + [ + "Command", + "Description" + ].map(htmlEscape), + subCommands + .map((subCommand) => { + if (subCommand.command == null || subCommand.describe === false) + return null; + + const resolvedCommandCliCommand = resolveCommandCliCommand(subCommand) ?? ""; + const commandPageLink = resolveCommandPageLink(subCommand); + + let cliCommand = resolvedCommandCliCommand; + cliCommand = (currentCommandCliCommand ?? "").replace("", cliCommand); + + if (parentCommand != null) + cliCommand = (resolvedParentCommandCliCommand ?? "").replace("", cliCommand); + + return [ + `` + htmlEscape(cliName + " " + cliCommand) + "", + htmlEscape(String(subCommand.describe ?? "")) + ]; + }) + .filter((row): row is string[] => row != null) + ); + } + if (optionGroups.length !== 0) { res += buildHtmlHeading("h2", htmlEscape("Options"), "options"); @@ -37,7 +78,10 @@ export async function getCommandHtmlDoc(command: CommandModule, cliNam } -async function getOptionsGroupFromCommand(command: CommandModule): Promise { +async function parseCommandDefinition(command: CommandModule): Promise<{ + subCommands: CommandModule[], + optionGroups: OptionsGroup[] +}> { const yargsStub = getYargsStub(); function getYargsStub() { function option(name: string, option: Options) { @@ -57,10 +101,16 @@ async function getOptionsGroupFromCommand(command: CommandModule): Pro return yargsStub; } - return {option}; + function command(subCommand: CommandModule) { + subCommands.push(subCommand); + return yargsStub; + } + + return {option, command}; } const options: Record = {}; + const subCommands: CommandModule[] = []; const groups: string[] = []; if (command.builder instanceof Function) @@ -97,10 +147,13 @@ async function getOptionsGroupFromCommand(command: CommandModule): Pro return 0; }); - return groups.map((group) => ({ - name: normalizeGroupName(group), - options: options[group]! - })); + return { + subCommands, + optionGroups: groups.map((group) => ({ + name: normalizeGroupName(group), + options: options[group]! + })) + }; } function normalizeGroupName(groupName: string): string { @@ -184,6 +237,19 @@ function renderOptionsGroupOptionsTable(options: {name: string, option: Options} return buildHtmlTable(tableHeaders, tableRows); } +function resolveCommandCliCommand(command?: CommandModule) { + if (command == null) + return undefined; + + return command.command instanceof Array + ? command.command[0] + : command.command; +} + +function resolveCommandPageLink(command: CommandModule) { + return resolveCommandCliCommand(command)?.split(" ")?.[0]; +} + type OptionsGroup = { name: string, options: Array<{ diff --git a/docs/guide/cli/cli.data.ts b/docs/guide/cli/cli.data.ts index 0c84bf87..9161f0cc 100644 --- a/docs/guide/cli/cli.data.ts +++ b/docs/guide/cli/cli.data.ts @@ -4,7 +4,9 @@ import {BuildCommand} from "../../../src/cli/commands/BuildCommand.js"; import {ChatCommand} from "../../../src/cli/commands/ChatCommand.js"; import {CompleteCommand} from "../../../src/cli/commands/CompleteCommand.js"; import {InfillCommand} from "../../../src/cli/commands/InfillCommand.js"; -import {InspectCommand} from "../../../src/cli/commands/InspectCommand.js"; +import {InspectCommand} from "../../../src/cli/commands/inspect/InspectCommand.js"; +import {InspectGpuCommand} from "../../../src/cli/commands/inspect/commands/InspectGpuCommand.js"; +import {InspectGgufCommand} from "../../../src/cli/commands/inspect/commands/InspectGgufCommand.js"; import {DownloadCommand} from "../../../src/cli/commands/DownloadCommand.js"; import {ClearCommand} from "../../../src/cli/commands/ClearCommand.js"; import {htmlEscape} from "../../../.vitepress/utils/htmlEscape.js"; @@ -31,7 +33,17 @@ export default { chat: await getCommandHtmlDoc(ChatCommand), complete: await getCommandHtmlDoc(CompleteCommand), infill: await getCommandHtmlDoc(InfillCommand), - inspect: await getCommandHtmlDoc(InspectCommand), + inspect: { + index: await getCommandHtmlDoc(InspectCommand, { + subCommandsParentPageLink: "inspect" + }), + gpu: await getCommandHtmlDoc(InspectGpuCommand, { + parentCommand: InspectCommand + }), + gguf: await getCommandHtmlDoc(InspectGgufCommand, { + parentCommand: InspectCommand + }) + }, download: await getCommandHtmlDoc(DownloadCommand), build: await getCommandHtmlDoc(BuildCommand), clear: await getCommandHtmlDoc(ClearCommand) diff --git a/docs/guide/cli/inspect.md b/docs/guide/cli/inspect.md index e2edfebd..c55d6bc3 100644 --- a/docs/guide/cli/inspect.md +++ b/docs/guide/cli/inspect.md @@ -5,7 +5,7 @@ outline: deep {{commandDoc.description}} diff --git a/docs/guide/cli/inspect/gguf.md b/docs/guide/cli/inspect/gguf.md new file mode 100644 index 00000000..5fd2b097 --- /dev/null +++ b/docs/guide/cli/inspect/gguf.md @@ -0,0 +1,17 @@ +--- +outline: deep +--- +# `inspect gguf` command + + + +{{commandDoc.description}} + +## Usage +```shell-vue +{{commandDoc.usage}} +``` +
diff --git a/docs/guide/cli/inspect/gpu.md b/docs/guide/cli/inspect/gpu.md new file mode 100644 index 00000000..a397187d --- /dev/null +++ b/docs/guide/cli/inspect/gpu.md @@ -0,0 +1,17 @@ +--- +outline: deep +--- +# `inspect gpu` command + + + +{{commandDoc.description}} + +## Usage +```shell-vue +{{commandDoc.usage}} +``` +
diff --git a/src/cli/cli.ts b/src/cli/cli.ts index 05c0820d..8e836a94 100644 --- a/src/cli/cli.ts +++ b/src/cli/cli.ts @@ -15,7 +15,7 @@ import {ClearCommand} from "./commands/ClearCommand.js"; import {ChatCommand} from "./commands/ChatCommand.js"; import {CompleteCommand} from "./commands/CompleteCommand.js"; import {InfillCommand} from "./commands/InfillCommand.js"; -import {InspectCommand} from "./commands/InspectCommand.js"; +import {InspectCommand} from "./commands/inspect/InspectCommand.js"; import {DebugCommand} from "./commands/DebugCommand.js"; const __dirname = path.dirname(fileURLToPath(import.meta.url)); diff --git a/src/cli/commands/InspectCommand.ts b/src/cli/commands/InspectCommand.ts deleted file mode 100644 index 285522e3..00000000 --- a/src/cli/commands/InspectCommand.ts +++ /dev/null @@ -1,127 +0,0 @@ -import os from "os"; -import {CommandModule} from "yargs"; -import bytes from "bytes"; -import chalk from "chalk"; -import {getLlamaForOptions} from "../../bindings/getLlama.js"; -import {detectAvailableComputeLayers} from "../../bindings/utils/detectAvailableComputeLayers.js"; -import {getPlatform} from "../../bindings/utils/getPlatform.js"; -import {BuildGpu, LlamaLogLevel} from "../../bindings/types.js"; -import {getPrettyBuildGpuName} from "../../bindings/consts.js"; - -const inspectFunctions = ["gpu"] as const; -type InspectCommand = { - function: (typeof inspectFunctions)[number] -}; - -export const InspectCommand: CommandModule = { - command: "inspect [function]", - describe: "Inspect the inner workings of node-llama-cpp", - builder(yargs) { - return yargs - .option("function", { - type: "string", - choices: inspectFunctions, - demandOption: true, - description: "inspect function to run" - }); - }, - async handler({function: func}: InspectCommand) { - if (func === "gpu") - await InspectGpuFunction(); - else - void (func satisfies never); - } -}; - -async function InspectGpuFunction() { - const platform = getPlatform(); - const arch = process.arch; - const availableComputeLayers = await detectAvailableComputeLayers({platform}); - const gpusToLogVramUsageOf: BuildGpu[] = []; - - if (platform === "mac" && arch === "arm64") { - console.info(`${chalk.yellow("Metal:")} ${chalk.green("available")}`); - gpusToLogVramUsageOf.push("metal"); - } else if (platform === "mac") { - console.info(`${chalk.yellow("Metal:")} ${chalk.red("not supported by llama.cpp on Intel Macs")}`); - } - - if (availableComputeLayers.cuda.hasNvidiaDriver && !availableComputeLayers.cuda.hasCudaRuntime) { - console.info(`${chalk.yellow("CUDA:")} ${chalk.red("NVIDIA driver is installed, but CUDA runtime is not")}`); - } else if (availableComputeLayers.cuda.hasCudaRuntime && !availableComputeLayers.cuda.hasNvidiaDriver) { - console.info(`${chalk.yellow("CUDA:")} ${chalk.red("CUDA runtime is installed, but NVIDIA driver is not")}`); - } else if (availableComputeLayers.cuda.hasCudaRuntime && availableComputeLayers.cuda.hasNvidiaDriver) { - console.info(`${chalk.yellow("CUDA:")} ${chalk.green("available")}`); - gpusToLogVramUsageOf.push("cuda"); - } - - if (availableComputeLayers.vulkan) { - console.info(`${chalk.yellow("Vulkan:")} ${chalk.green("available")}`); - gpusToLogVramUsageOf.push("vulkan"); - } - - for (const gpu of gpusToLogVramUsageOf) { - console.info(); - await logGpuVramUsage(gpu); - } - - console.info(); - await logRamUsage(); -} - -async function logGpuVramUsage(gpu: BuildGpu) { - try { - const llama = await getLlamaForOptions({ - gpu: gpu, - build: "never", - progressLogs: false, - logLevel: LlamaLogLevel.warn - }, { - skipLlamaInit: true - }); - const gpuName = getPrettyBuildGpuName(gpu); - const vramStatus = llama.getVramState(); - - console.info(`${chalk.yellow(`${gpuName} used VRAM:`)} ${getPercentageString(vramStatus.used, vramStatus.total)}% ${chalk.grey("(" + bytes(vramStatus.used) + "/" + bytes(vramStatus.total) + ")")}`); - console.info(`${chalk.yellow(`${gpuName} free VRAM:`)} ${getPercentageString(vramStatus.free, vramStatus.total)}% ${chalk.grey("(" + bytes(vramStatus.free) + "/" + bytes(vramStatus.total) + ")")}`); - } catch (err) {} -} - -async function logRamUsage() { - const totalMemory = os.totalmem(); - const freeMemory = os.freemem(); - const usedMemory = totalMemory - freeMemory; - - console.info(`${chalk.yellow("Used RAM:")} ${getPercentageString(usedMemory, totalMemory)}% ${chalk.grey("(" + bytes(usedMemory) + "/" + bytes(totalMemory) + ")")}`); - console.info(`${chalk.yellow("Free RAM:")} ${getPercentageString(freeMemory, totalMemory)}% ${chalk.grey("(" + bytes(freeMemory) + "/" + bytes(totalMemory) + ")")}`); -} - -function getPercentageString(amount: number, total: number) { - if (total === 0) - return "0"; - - return String(Math.floor((amount / total) * 100 * 100) / 100); -} - -// // simple script to copy console logs as ansi to clipboard. Used to update the documentation -// import {spawn} from "child_process"; -// const pendingLog: string[] = []; -// const originalConsoleInfo = console.info; -// console.info = function info(...args: any[]) { -// originalConsoleInfo.call(console, ...args); -// pendingLog.push(args.join(" ")); -// }; -// -// function copyLogs() { -// const res = pendingLog.join("\n"); -// -// pbcopy(res); -// originalConsoleInfo.call(console, "Copied logs to clipboard"); -// } -// function pbcopy(text: string) { -// const pbcopyProcess = spawn("pbcopy"); -// pbcopyProcess.stdin.write(text); -// pbcopyProcess.stdin.end(); -// } -// -// process.on("exit", copyLogs); diff --git a/src/cli/commands/inspect/InspectCommand.ts b/src/cli/commands/inspect/InspectCommand.ts new file mode 100644 index 00000000..504c4e2e --- /dev/null +++ b/src/cli/commands/inspect/InspectCommand.ts @@ -0,0 +1,20 @@ +import {CommandModule} from "yargs"; +import {InspectGgufCommand} from "./commands/InspectGgufCommand.js"; +import {InspectGpuCommand} from "./commands/InspectGpuCommand.js"; + +type InspectCommand = { + // no options for now +}; + +export const InspectCommand: CommandModule = { + command: "inspect ", + describe: "Inspect the inner workings of node-llama-cpp", + builder(yargs) { + return yargs + .command(InspectGpuCommand) + .command(InspectGgufCommand); + }, + async handler() { + // this function must exit, even though we do nothing here + } +}; diff --git a/src/cli/commands/inspect/commands/InspectGgufCommand.ts b/src/cli/commands/inspect/commands/InspectGgufCommand.ts new file mode 100644 index 00000000..033d1097 --- /dev/null +++ b/src/cli/commands/inspect/commands/InspectGgufCommand.ts @@ -0,0 +1,38 @@ +import path from "path"; +import {CommandModule} from "yargs"; +import chalk from "chalk"; +import bytes from "bytes"; +import {parseGgufMetadata} from "../../../../gguf/parseGgufMetadata.js"; +import {prettyPrintObject} from "../../../../utils/prettyPrintObject.js"; + +type InspectGgufCommand = { + path: string +}; + +export const InspectGgufCommand: CommandModule = { + command: "gguf [path]", + describe: "Inspect a GGUF file", + builder(yargs) { + return yargs + .option("path", { + type: "string", + demandOption: true, + description: "The path to the GGUF file to inspect" + }); + }, + async handler({path: ggufPath}: InspectGgufCommand) { + const isPathUrl = ggufPath.startsWith("http://") || ggufPath.startsWith("https://"); + const resolvedGgufPath = isPathUrl + ? ggufPath + : path.resolve(ggufPath); + + if (isPathUrl) + console.info(`${chalk.yellow("URL:")} ${resolvedGgufPath}`); + else + console.info(`${chalk.yellow("File:")} ${resolvedGgufPath}`); + + const parsedMetadata = await parseGgufMetadata(ggufPath, {ignoreKeys: []}); + console.info(`${chalk.yellow("GGUF metadata size:")} ${bytes(parsedMetadata.metadataSize)}`); + console.info(`${chalk.yellow("GGUF metadata:")} ${prettyPrintObject(parsedMetadata.metadata, undefined, {maxArrayValues: 10, useNumberGrouping: true})}`); + } +}; diff --git a/src/cli/commands/inspect/commands/InspectGpuCommand.ts b/src/cli/commands/inspect/commands/InspectGpuCommand.ts new file mode 100644 index 00000000..7c918fd8 --- /dev/null +++ b/src/cli/commands/inspect/commands/InspectGpuCommand.ts @@ -0,0 +1,110 @@ +import os from "os"; +import {CommandModule} from "yargs"; +import bytes from "bytes"; +import chalk from "chalk"; +import {getLlamaForOptions} from "../../../../bindings/getLlama.js"; +import {detectAvailableComputeLayers} from "../../../../bindings/utils/detectAvailableComputeLayers.js"; +import {getPlatform} from "../../../../bindings/utils/getPlatform.js"; +import {BuildGpu, LlamaLogLevel} from "../../../../bindings/types.js"; +import {getPrettyBuildGpuName} from "../../../../bindings/consts.js"; + +type InspectGpuCommand = { + // no options for now +}; + +export const InspectGpuCommand: CommandModule = { + command: "gpu", + describe: "Show the detected GPU types and their VRAM usage", + async handler() { + const platform = getPlatform(); + const arch = process.arch; + const availableComputeLayers = await detectAvailableComputeLayers({platform}); + const gpusToLogVramUsageOf: BuildGpu[] = []; + + if (platform === "mac" && arch === "arm64") { + console.info(`${chalk.yellow("Metal:")} ${chalk.green("available")}`); + gpusToLogVramUsageOf.push("metal"); + } else if (platform === "mac") { + console.info(`${chalk.yellow("Metal:")} ${chalk.red("not supported by llama.cpp on Intel Macs")}`); + } + + if (availableComputeLayers.cuda.hasNvidiaDriver && !availableComputeLayers.cuda.hasCudaRuntime) { + console.info(`${chalk.yellow("CUDA:")} ${chalk.red("NVIDIA driver is installed, but CUDA runtime is not")}`); + } else if (availableComputeLayers.cuda.hasCudaRuntime && !availableComputeLayers.cuda.hasNvidiaDriver) { + console.info(`${chalk.yellow("CUDA:")} ${chalk.red("CUDA runtime is installed, but NVIDIA driver is not")}`); + } else if (availableComputeLayers.cuda.hasCudaRuntime && availableComputeLayers.cuda.hasNvidiaDriver) { + console.info(`${chalk.yellow("CUDA:")} ${chalk.green("available")}`); + gpusToLogVramUsageOf.push("cuda"); + } + + if (availableComputeLayers.vulkan) { + console.info(`${chalk.yellow("Vulkan:")} ${chalk.green("available")}`); + gpusToLogVramUsageOf.push("vulkan"); + } + + for (const gpu of gpusToLogVramUsageOf) { + console.info(); + await logGpuVramUsage(gpu); + } + + console.info(); + await logRamUsage(); + } +}; + +async function logGpuVramUsage(gpu: BuildGpu) { + try { + const llama = await getLlamaForOptions({ + gpu: gpu, + build: "never", + progressLogs: false, + logLevel: LlamaLogLevel.warn + }, { + skipLlamaInit: true + }); + const gpuName = getPrettyBuildGpuName(gpu); + const vramStatus = llama.getVramState(); + + console.info(`${chalk.yellow(`${gpuName} used VRAM:`)} ${getPercentageString(vramStatus.used, vramStatus.total)}% ${chalk.grey("(" + bytes(vramStatus.used) + "/" + bytes(vramStatus.total) + ")")}`); + console.info(`${chalk.yellow(`${gpuName} free VRAM:`)} ${getPercentageString(vramStatus.free, vramStatus.total)}% ${chalk.grey("(" + bytes(vramStatus.free) + "/" + bytes(vramStatus.total) + ")")}`); + } catch (err) {} +} + +async function logRamUsage() { + const totalMemory = os.totalmem(); + const freeMemory = os.freemem(); + const usedMemory = totalMemory - freeMemory; + + console.info(`${chalk.yellow("Used RAM:")} ${getPercentageString(usedMemory, totalMemory)}% ${chalk.grey("(" + bytes(usedMemory) + "/" + bytes(totalMemory) + ")")}`); + console.info(`${chalk.yellow("Free RAM:")} ${getPercentageString(freeMemory, totalMemory)}% ${chalk.grey("(" + bytes(freeMemory) + "/" + bytes(totalMemory) + ")")}`); +} + +function getPercentageString(amount: number, total: number) { + if (total === 0) + return "0"; + + return String(Math.floor((amount / total) * 100 * 100) / 100); +} + +// // simple script to copy console logs as ansi to clipboard. Used to update the documentation +// import {spawn} from "child_process"; +// const pendingLog: string[] = []; +// const originalConsoleInfo = console.info; +// console.info = function info(...args: any[]) { +// originalConsoleInfo.call(console, ...args); +// pendingLog.push(args.join(" ")); +// }; +// +// function copyLogs() { +// const res = pendingLog.join("\n"); +// +// pbcopy(res); +// originalConsoleInfo.call(console, "Copied logs to clipboard"); +// } +// function pbcopy(text: string) { +// const pbcopyProcess = spawn("pbcopy"); +// pbcopyProcess.stdin.write(text); +// pbcopyProcess.stdin.end(); +// } +// +// process.on("exit", copyLogs); diff --git a/src/utils/prettyPrintObject.ts b/src/utils/prettyPrintObject.ts index d9af60d7..69947b2e 100644 --- a/src/utils/prettyPrintObject.ts +++ b/src/utils/prettyPrintObject.ts @@ -1,10 +1,22 @@ import chalk from "chalk"; -export function prettyPrintObject(obj: any, indent: number = 4): string { +type PrettyPrintObjectOptions = { + maxArrayValues?: number, + useNumberGrouping?: boolean +}; + +export function prettyPrintObject(obj: any, indent: number = 4, options: PrettyPrintObjectOptions = {}): string { if (typeof obj === "string") return chalk.green(JSON.stringify(obj, null, 4)); else if (typeof obj === "number") - return chalk.yellow(obj); + return chalk.yellow( + options.useNumberGrouping + ? obj.toLocaleString("en-US", { + style: "decimal", + useGrouping: true + }).replaceAll(",", "_") + : obj + ); else if (typeof obj === "boolean") return chalk.magenta.italic(obj); else if (obj === null) @@ -12,12 +24,7 @@ export function prettyPrintObject(obj: any, indent: number = 4): string { else if (obj === undefined) return chalk.magenta.italic("undefined"); else if (obj instanceof Array) - return [ - chalk.whiteBright("["), - obj.map(prettyPrintObject) - .join(chalk.whiteBright(", ")), - chalk.whiteBright("]") - ].join(""); + return prettyPrintArray(obj, indent, options); const rows: string[] = []; for (const key of Object.keys(obj)) { @@ -29,7 +36,7 @@ export function prettyPrintObject(obj: any, indent: number = 4): string { ? chalk.red(key) : chalk.green(JSON.stringify(key)), chalk.whiteBright(": "), - prettyPrintObject(value, indent) + prettyPrintObject(value, indent, options) .replaceAll("\n", "\n" + " ".repeat(indent)) ].join("")); } @@ -43,3 +50,22 @@ export function prettyPrintObject(obj: any, indent: number = 4): string { function canStringBeKeyWithoutQuotes(key: string): boolean { return JSON.stringify(key).slice(1, -1) === key && /^[a-zA-Z_][a-zA-Z0-9_]*$/.test(key); } + +function prettyPrintArray(arr: any[], indent: number = 4, options: PrettyPrintObjectOptions = {}) { + const slicedArray = (options.maxArrayValues != null && arr.length > options.maxArrayValues) + ? arr.slice(0, options.maxArrayValues) + : arr; + const hiddenItems = arr.length - slicedArray.length; + + return [ + chalk.whiteBright("["), + slicedArray.map((item) => prettyPrintObject(item, indent, options)) + .concat( + hiddenItems > 0 + ? [chalk.white("..." + hiddenItems + " more item" + (hiddenItems !== 1 ? "s" : ""))] + : [] + ) + .join(chalk.whiteBright(", ")), + chalk.whiteBright("]") + ].join(""); +} From ab94b639f02039f728339cd94754861aa2be4039 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 19 Mar 2024 21:24:11 +0200 Subject: [PATCH 12/52] fix: rename `build` option to `builds` in the `clear` command --- src/cli/commands/ClearCommand.ts | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/cli/commands/ClearCommand.ts b/src/cli/commands/ClearCommand.ts index 9fb446a1..fa696f74 100644 --- a/src/cli/commands/ClearCommand.ts +++ b/src/cli/commands/ClearCommand.ts @@ -7,7 +7,7 @@ import {clearAllLocalBuilds} from "../../bindings/utils/clearAllLocalBuilds.js"; import {clearLocalCmake, fixXpackPermissions} from "../../utils/cmake.js"; type ClearCommand = { - type: "source" | "build" | "cmake" | "all" + type: "source" | "builds" | "cmake" | "all" }; export const ClearCommand: CommandModule = { @@ -18,7 +18,7 @@ export const ClearCommand: CommandModule = { return yargs .option("type", { type: "string", - choices: ["source", "build", "cmake", "all"] satisfies ClearCommand["type"][], + choices: ["source", "builds", "cmake", "all"] satisfies ClearCommand["type"][], default: "all" as ClearCommand["type"], description: "Files to clear" }); @@ -38,11 +38,11 @@ export async function ClearLlamaCppBuildCommand({type}: ClearCommand) { }); } - if (type === "build" || type === "all") { + if (type === "builds" || type === "all") { await withOra({ - loading: chalk.blue("Clearing build"), - success: chalk.blue("Cleared build"), - fail: chalk.blue("Failed to clear build") + loading: chalk.blue("Clearing all builds"), + success: chalk.blue("Cleared all builds"), + fail: chalk.blue("Failed to clear all builds") }, async () => { await clearAllLocalBuilds(); }); From 63ba9b9971b1f701af9a22b816c4a134cdb0bbcf Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 20 Mar 2024 02:51:30 +0200 Subject: [PATCH 13/52] feat: read tensor info from a GGUF file --- .../inspect/commands/InspectGgufCommand.ts | 46 +- src/gguf/GgufInsights.ts | 44 +- .../errors/UnsupportedGgufValueTypeError.ts | 11 + .../errors/UnsupportedMetadataTypeError.ts | 11 - ...arseGgufMetadata.ts => getGgufFileInfo.ts} | 12 +- src/gguf/ggufParser/GgufMetadataTypes.ts | 362 --------------- src/gguf/ggufParser/GgufParser.ts | 136 +++--- .../ggufParser/types/GgufFileInfoTypes.ts | 15 + .../ggufParser/types/GgufMetadataTypes.ts | 422 ++++++++++++++++++ .../ggufParser/types/GgufTensorInfoTypes.ts | 38 ++ .../ggufParser/utils/getGgufFileTypeName.ts | 14 + .../utils/getGgufMetadataLlmData.ts | 10 + .../utils/parseGgufFileTypeNumber.ts | 34 -- src/utils/mergeUnionTypes.ts | 4 + src/utils/prettyPrintObject.ts | 79 +++- .../__snapshots__/gguf.test.ts.snap | 160 ++++++- test/modelDependent/functionary/gguf.test.ts | 18 +- .../gguf/__snapshots__/gguf.test.ts.snap | 177 +++++++- test/standalone/gguf/gguf.test.ts | 21 +- .../simplifyGgufInfoForTestSnapshot.ts | 23 + 20 files changed, 1109 insertions(+), 528 deletions(-) create mode 100644 src/gguf/errors/UnsupportedGgufValueTypeError.ts delete mode 100644 src/gguf/errors/UnsupportedMetadataTypeError.ts rename src/gguf/{parseGgufMetadata.ts => getGgufFileInfo.ts} (80%) delete mode 100644 src/gguf/ggufParser/GgufMetadataTypes.ts create mode 100644 src/gguf/ggufParser/types/GgufFileInfoTypes.ts create mode 100644 src/gguf/ggufParser/types/GgufMetadataTypes.ts create mode 100644 src/gguf/ggufParser/types/GgufTensorInfoTypes.ts create mode 100644 src/gguf/ggufParser/utils/getGgufFileTypeName.ts create mode 100644 src/gguf/ggufParser/utils/getGgufMetadataLlmData.ts delete mode 100644 src/gguf/ggufParser/utils/parseGgufFileTypeNumber.ts create mode 100644 test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts diff --git a/src/cli/commands/inspect/commands/InspectGgufCommand.ts b/src/cli/commands/inspect/commands/InspectGgufCommand.ts index 033d1097..92e362b8 100644 --- a/src/cli/commands/inspect/commands/InspectGgufCommand.ts +++ b/src/cli/commands/inspect/commands/InspectGgufCommand.ts @@ -2,11 +2,13 @@ import path from "path"; import {CommandModule} from "yargs"; import chalk from "chalk"; import bytes from "bytes"; -import {parseGgufMetadata} from "../../../../gguf/parseGgufMetadata.js"; -import {prettyPrintObject} from "../../../../utils/prettyPrintObject.js"; +import {getGgufFileInfo} from "../../../../gguf/getGgufFileInfo.js"; +import {prettyPrintObject, PrettyPrintObjectOptions} from "../../../../utils/prettyPrintObject.js"; +import {getGgufFileTypeName} from "../../../../gguf/ggufParser/utils/getGgufFileTypeName.js"; type InspectGgufCommand = { - path: string + path: string, + fullTensorInfo: boolean }; export const InspectGgufCommand: CommandModule = { @@ -18,9 +20,15 @@ export const InspectGgufCommand: CommandModule = { type: "string", demandOption: true, description: "The path to the GGUF file to inspect" + }) + .option("fullTensorInfo", { + alias: "t", + type: "boolean", + default: false, + description: "Show the full tensor info" }); }, - async handler({path: ggufPath}: InspectGgufCommand) { + async handler({path: ggufPath, fullTensorInfo}: InspectGgufCommand) { const isPathUrl = ggufPath.startsWith("http://") || ggufPath.startsWith("https://"); const resolvedGgufPath = isPathUrl ? ggufPath @@ -31,8 +39,32 @@ export const InspectGgufCommand: CommandModule = { else console.info(`${chalk.yellow("File:")} ${resolvedGgufPath}`); - const parsedMetadata = await parseGgufMetadata(ggufPath, {ignoreKeys: []}); - console.info(`${chalk.yellow("GGUF metadata size:")} ${bytes(parsedMetadata.metadataSize)}`); - console.info(`${chalk.yellow("GGUF metadata:")} ${prettyPrintObject(parsedMetadata.metadata, undefined, {maxArrayValues: 10, useNumberGrouping: true})}`); + const parsedMetadata = await getGgufFileInfo(ggufPath, {ignoreKeys: []}); + const fileTypeName = getGgufFileTypeName(parsedMetadata.metadata.general?.file_type); + const metadataPrettyPrintOptions: PrettyPrintObjectOptions = { + maxArrayValues: 10, + useNumberGrouping: true, + maxArrayItemsWidth: process.stdout.columns - 1 + }; + const tensorInfoPrettyPrintOptions: PrettyPrintObjectOptions = { + maxArrayValues: fullTensorInfo + ? undefined + : 4, + useNumberGrouping: true, + maxArrayItemsWidth: process.stdout.columns - 1, + multilineObjects: false + }; + const numberLocaleFormattingOptions = { + style: "decimal", + useGrouping: true + } as const; + + console.info(`${chalk.yellow("GGUF version:")} ${parsedMetadata.version}`); + console.info(`${chalk.yellow("Tensor count:")} ${parsedMetadata.tensorCount.toLocaleString("en-US", numberLocaleFormattingOptions)}`); + console.info(`${chalk.yellow("Metadata size:")} ${bytes(parsedMetadata.metadataSize)}`); + console.info(`${chalk.yellow("Tensor info size:")} ${bytes(parsedMetadata.tensorInfoSize!)}`); + console.info(`${chalk.yellow("File type:")} ${fileTypeName ?? ""} ${chalk.white(`(${parsedMetadata.metadata.general?.file_type})`)}`); + console.info(`${chalk.yellow("Metadata:")} ${prettyPrintObject(parsedMetadata.metadata, undefined, metadataPrettyPrintOptions)}`); + console.info(`${chalk.yellow("Tensor info:")} ${prettyPrintObject(parsedMetadata.tensorInfo, undefined, tensorInfoPrettyPrintOptions)}`); } }; diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts index 0ba9e8b4..68e399d3 100644 --- a/src/gguf/GgufInsights.ts +++ b/src/gguf/GgufInsights.ts @@ -1,18 +1,19 @@ -import {GgufParsedMetadataResult} from "./ggufParser/GgufParser.js"; +import {getGgufMetadataLlmData} from "./ggufParser/utils/getGgufMetadataLlmData.js"; +import {GgufMetadata} from "./ggufParser/types/GgufMetadataTypes.js"; export class GgufInsights { - public readonly metadataResponse: GgufParsedMetadataResult; + public readonly metadata: GgufMetadata; + public readonly metadataSize: number; - public constructor(metadataResponse: GgufParsedMetadataResult) { - this.metadataResponse = metadataResponse; - } - - public get metadata() { - return this.metadataResponse.metadata; - } - - public get architectureMetadata() { - return this.metadata[this.metadata.general.architecture]; + public constructor({ + metadata, + metadataSize + }: { + metadata: GgufMetadata, + metadataSize: number + }) { + this.metadata = metadata; + this.metadataSize = metadataSize; } /** @@ -20,13 +21,14 @@ export class GgufInsights { */ public get kvMatrices() { // 2 bytes each * 2 key and value + const llmData = getGgufMetadataLlmData(this.metadata); return ( 2 * 2 * - this.architectureMetadata.context_length * - this.architectureMetadata.block_count * - this.architectureMetadata.embedding_length * - this.architectureMetadata.attention.head_count_kv / - this.architectureMetadata.attention.head_count + (llmData.context_length ?? 1) * + (llmData.block_count ?? 1) * + (llmData.embedding_length ?? 1) * + (llmData.attention?.head_count_kv ?? 1) / + (llmData.attention?.head_count ?? 1) ); } @@ -36,10 +38,14 @@ export class GgufInsights { public get graphSize() { // TODO: get this from the llama.cpp's graph calculations instead of // estimating it's 1/6 * kv_cache_size * num_gqa - return (this.architectureMetadata.attention.head_count_kv / this.architectureMetadata.attention.head_count) * this.kvMatrices / 6; + const llmData = getGgufMetadataLlmData(this.metadata); + return ( + (llmData.attention?.head_count_kv ?? 1) / + (llmData.attention?.head_count ?? 1) + ) * this.kvMatrices / 6; } public get VRAMUsage() { - return this.graphSize + this.kvMatrices + this.metadataResponse.metadataSize; + return this.graphSize + this.kvMatrices + this.metadataSize; } } diff --git a/src/gguf/errors/UnsupportedGgufValueTypeError.ts b/src/gguf/errors/UnsupportedGgufValueTypeError.ts new file mode 100644 index 00000000..564fcb06 --- /dev/null +++ b/src/gguf/errors/UnsupportedGgufValueTypeError.ts @@ -0,0 +1,11 @@ +export class UnsupportedGgufValueTypeError extends Error { + public readonly ggufValueType: number; + + public constructor(ggufValueType: number) { + super(`Unsupported GGUF value type "${ggufValueType}"`); + + Object.defineProperty(this, "ggufValueType" satisfies keyof this, {enumerable: false}); + + this.ggufValueType = ggufValueType; + } +} diff --git a/src/gguf/errors/UnsupportedMetadataTypeError.ts b/src/gguf/errors/UnsupportedMetadataTypeError.ts deleted file mode 100644 index 3be0a8b9..00000000 --- a/src/gguf/errors/UnsupportedMetadataTypeError.ts +++ /dev/null @@ -1,11 +0,0 @@ -export default class UnsupportedMetadataTypeError extends Error { - public readonly metadataValueType: number; - - public constructor(metadataValueType: number) { - super(`Unsupported GGUF metadata value type "${metadataValueType}"`); - - Object.defineProperty(this, "metadataValueType" satisfies keyof this, {enumerable: false}); - - this.metadataValueType = metadataValueType; - } -} diff --git a/src/gguf/parseGgufMetadata.ts b/src/gguf/getGgufFileInfo.ts similarity index 80% rename from src/gguf/parseGgufMetadata.ts rename to src/gguf/getGgufFileInfo.ts index 4f16d5cf..137fcb18 100644 --- a/src/gguf/parseGgufMetadata.ts +++ b/src/gguf/getGgufFileInfo.ts @@ -5,11 +5,16 @@ import {GgufFsFileReader} from "./ggufParser/fileReaders/GgufFsFileReader.js"; import {ggufDefaultRetryOptions} from "./consts.js"; -export async function parseGgufMetadata(pathOrUrl: string, { +/** + * Parse a GGUF file and return its metadata and tensor info (unless `readTensorInfo` is set to `false`) + */ +export async function getGgufFileInfo(pathOrUrl: string, { + readTensorInfo = true, sourceType, retryOptions = ggufDefaultRetryOptions, ignoreKeys = [] }: { + readTensorInfo?: boolean, sourceType?: "network" | "filesystem", retryOptions?: retry.Options, ignoreKeys?: string[] @@ -34,8 +39,9 @@ export async function parseGgufMetadata(pathOrUrl: string, { const fileReader = createFileReader(); const parser = new GgufParser({ fileReader, - ignoreKeys + ignoreKeys, + readTensorInfo }); - return await parser.parseMetadata(); + return await parser.parseFileInfo(); } diff --git a/src/gguf/ggufParser/GgufMetadataTypes.ts b/src/gguf/ggufParser/GgufMetadataTypes.ts deleted file mode 100644 index a4a142b9..00000000 --- a/src/gguf/ggufParser/GgufMetadataTypes.ts +++ /dev/null @@ -1,362 +0,0 @@ -export const enum GgufArchitectureType { - llama = "llama", - falcon = "falcon", - mpt = "mpt", - gptneox = "gptneox", - gptj = "gptj", - gpt2 = "gpt2", - bloom = "bloom", - rwkv = "rwkv", - whisper = "whisper" -} - -export const enum GgufFileType { - ALL_F32 = "ALL_F32", - MOSTLY_F16 = "MOSTLY_F16", - MOSTLY_Q4_0 = "MOSTLY_Q4_0", - MOSTLY_Q4_1 = "MOSTLY_Q4_1", - MOSTLY_Q4_1_SOME_F16 = "MOSTLY_Q4_1_SOME_F16", - MOSTLY_Q4_2 = "MOSTLY_Q4_2", - MOSTLY_Q4_3 = "MOSTLY_Q4_3", - MOSTLY_Q8_0 = "MOSTLY_Q8_0", - MOSTLY_Q5_0 = "MOSTLY_Q5_0", - MOSTLY_Q5_1 = "MOSTLY_Q5_1", - MOSTLY_Q2_K = "MOSTLY_Q2_K", - MOSTLY_Q3_K_S = "MOSTLY_Q3_K_S", - MOSTLY_Q3_K_M = "MOSTLY_Q3_K_M", - MOSTLY_Q3_K_L = "MOSTLY_Q3_K_L", - MOSTLY_Q4_K_S = "MOSTLY_Q4_K_S", - MOSTLY_Q4_K_M = "MOSTLY_Q4_K_M", - MOSTLY_Q5_K_S = "MOSTLY_Q5_K_S", - MOSTLY_Q5_K_M = "MOSTLY_Q5_K_M", - MOSTLY_Q6_K = "MOSTLY_Q6_K" -} - - -export type GgufMetadataArchitectureProperties = { - context_length: number, - embedding_length: number, - block_count: number, - feed_forward_length: number, - use_parallel_residual: boolean, - tensor_data_layout: string, - expert_count: number, - expert_used_count: number, - - attention: { - head_count: number, - head_count_kv: number, - max_alibi_bias: number, - clamp_kqv: number, - key_length: number, - value_length: number, - layer_norm_epsilon: number, - layer_norm_rms_epsilon: number - }, - - rope: { - dimension_count: number, - freq_base: number, - scaling: { - type: string, - factor: number, - original_context_length: number, - finetuned: boolean - } - } -}; - -export type GgufMetadataGeneralProperties = { - architecture: GgufArchitectureType, - - /** - * The version of the quantization format. Not required if the model is not - * quantized (i.e. no tensors are quantized). If any tensors are quantized, - * this must be present. This is separate to the quantization scheme of the - * tensors itself; the quantization version may change without changing the - * scheme's name (e.g. the quantization scheme is Q5_K, and the quantization - * version is 4). - */ - quantization_version: string, - - /** - * the global alignment to use, as described above. This can vary to allow - * for different alignment schemes, but it must be a multiple of 8. Some - * writers may not write the alignment. If the alignment is not specified, - * assume it is `32`. - */ - alignment: string, - - /** - * The name of the model. This should be a human-readable name that can be - * used to identify the model. It should be unique within the community - * that the model is defined in. - */ - name: string, - author: string, - - /** - * URL to the model's homepage. This can be a GitHub repo, a paper, etc. - */ - url: string, - - /** - * free-form description of the model including anything that isn't - * covered by the other fields - */ - description: string, - - /** - * License of the model, expressed as a SPDX license expression - * (e.g. `MIT OR Apache-2.0`). *Should not* include any other information, - * such as the license text or the URL to the license. - */ - license: string, - - /** - * Information about where this model came from. This is useful for tracking - * the provenance of the model, and for finding the original source if the - * model is modified. For a model that was converted from GGML, for - * example, these keys would point to the model that was converted from. - */ - source: { - /** - * URL to the source of the model. Can be a GitHub repo, a paper, etc. - */ - url: string, - huggingface: { - repository: string - } - }, - - /** - * An enumerated value describing the type of the majority of the tensors - * in the file. Optional; can be inferred from the tensor types. - */ - file_type?: GgufFileType | undefined -}; - -export type GgufMetadataAny = { - general: GgufMetadataGeneralProperties -} & { - [key in GgufArchitectureType]: GgufMetadataArchitectureProperties -}; - -export type GgufMetadataLLAMA = { - general: GgufMetadataGeneralProperties & { - architecture: GgufArchitectureType.llama - }, - - llama: { - context_length: number, - embedding_length: number, - block_count: number, - feed_forward_length: number, - attention: { - head_count: number, - layer_norm_rms_epsilon: number, - head_count_kv?: number - }, - rope: { - dimension_count: number, - scale?: number - }, - expert_count?: number, - expert_used_count?: number, - tensor_data_layout?: string - } -}; - -export type GgufMetadataFalcon = { - general: GgufMetadataGeneralProperties & { - architecture: GgufArchitectureType.falcon - }, - - falcon: { - context_length: number, - embedding_length: number, - block_count: number, - attention: { - head_count: number, - head_count_kv: number, - use_norm: boolean, - layer_norm_epsilon: number - }, - tensor_data_layout?: string - } -}; - -export type GgufMetadataMPT = { - general: GgufMetadataGeneralProperties & { - architecture: GgufArchitectureType.mpt - }, - - mpt: { - context_length: number, - embedding_length: number, - block_count: number, - attention: { - head_count: number, - alibi_bias_max: number, - clip_kqv: number, - layer_norm_epsilon: number - } - } -}; - -export type GgufMetadataGPTNeoX = { - general: GgufMetadataGeneralProperties & { - architecture: GgufArchitectureType.gptneox - }, - - gptneox: { - context_length: number, - embedding_length: number, - block_count: number, - use_parallel_residual: boolean, - rope: { - dimension_count: number, - freq_base: number, - scale?: number - }, - attention: { - head_count: number, - layer_norm_epsilon: number - } - } -}; - -export type GgufMetadataGPTJ = { - general: GgufMetadataGeneralProperties & { - architecture: GgufArchitectureType.gptj - }, - - gptj: { - context_length: number, - embedding_length: number, - block_count: number, - rope: { - dimension_count: number, - scale?: number - }, - attention: { - head_count: number, - layer_norm_epsilon: number - } - } -}; - -export type GgufMetadataGPT2 = { - general: GgufMetadataGeneralProperties & { - architecture: GgufArchitectureType.gpt2 - }, - - gpt2: { - context_length: number, - embedding_length: number, - block_count: number, - attention: { - head_count: number, - layer_norm_epsilon: number - } - } -}; - -export type GgufMetadataBloom = { - general: GgufMetadataGeneralProperties & { - architecture: GgufArchitectureType.bloom - }, - - bloom: { - context_length: number, - embedding_length: number, - block_count: number, - feed_forward_length: number, - attention: { - head_count: number, - layer_norm_epsilon: number - } - } -}; - -export type GgufMetadataRWKV = { - general: GgufMetadataGeneralProperties & { - architecture: GgufArchitectureType.rwkv - }, - - rwkv: { - context_length: number, - block_count: number, - embedding_length: number, - feed_forward_length: number - } -}; - -export type GgufMetadataWhisper = { - general: GgufMetadataGeneralProperties & { - architecture: GgufArchitectureType.whisper - }, - whisper: { - encoder: { - context_length: number, - embedding_length: number, - block_count: number, - mels_count: number, - attention: { - head_count: number - } - }, - decoder: { - context_length: number, - embedding_length: number, - block_count: number, - attention: { - head_count: number - } - } - } -}; - -export type GgufMetadata = - | GgufMetadataLLAMA - | GgufMetadataFalcon - | GgufMetadataMPT - | GgufMetadataGPTNeoX - | GgufMetadataGPTJ - | GgufMetadataGPT2 - | GgufMetadataBloom - | GgufMetadataRWKV - | GgufMetadataWhisper; - - -export function isLlamaMetadata(metadata: GgufMetadata): metadata is GgufMetadataLLAMA { - return metadata.general.architecture === GgufArchitectureType.llama; -} - -export function isMPTMetadata(metadata: GgufMetadata): metadata is GgufMetadataMPT { - return metadata.general.architecture === GgufArchitectureType.mpt; -} - -export function isGPTNeoXMetadata(metadata: GgufMetadata): metadata is GgufMetadataGPTNeoX { - return metadata.general.architecture === GgufArchitectureType.gptneox; -} - -export function isGPTJMetadata(metadata: GgufMetadata): metadata is GgufMetadataGPTJ { - return metadata.general.architecture === GgufArchitectureType.gptj; -} - -export function isGPT2Metadata(metadata: GgufMetadata): metadata is GgufMetadataGPT2 { - return metadata.general.architecture === GgufArchitectureType.gpt2; -} - -export function isBloomMetadata(metadata: GgufMetadata): metadata is GgufMetadataBloom { - return metadata.general.architecture === GgufArchitectureType.bloom; -} - -export function isFalconMetadata(metadata: GgufMetadata): metadata is GgufMetadataFalcon { - return metadata.general.architecture === GgufArchitectureType.falcon; -} - -export function isRWKVMetadata(metadata: GgufMetadata): metadata is GgufMetadataRWKV { - return metadata.general.architecture === GgufArchitectureType.rwkv; -} diff --git a/src/gguf/ggufParser/GgufParser.ts b/src/gguf/ggufParser/GgufParser.ts index 5c32b91b..091f4fab 100644 --- a/src/gguf/ggufParser/GgufParser.ts +++ b/src/gguf/ggufParser/GgufParser.ts @@ -1,12 +1,14 @@ import {InvalidGgufMagicError} from "../errors/InvalidGgufMagicError.js"; -import UnsupportedGgufMetadataTypeError from "../errors/UnsupportedMetadataTypeError.js"; +import {UnsupportedGgufValueTypeError} from "../errors/UnsupportedGgufValueTypeError.js"; import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; import {GgufReadOffset} from "./utils/GgufReadOffset.js"; -import {parseGgufFileTypeNumber} from "./utils/parseGgufFileTypeNumber.js"; -import {GgufMetadataAny} from "./GgufMetadataTypes.js"; +import {GgufMetadata} from "./types/GgufMetadataTypes.js"; import {GgufFileReader, valueTypeToBytesToRead} from "./fileReaders/GgufFileReader.js"; +import {GgufFileInfo} from "./types/GgufFileInfoTypes.js"; +import {GgmlType, GgufTensorInfo} from "./types/GgufTensorInfoTypes.js"; -const enum MetadataValueType { +// source: `enum gguf_type` in `ggml.h` in the `llama.cpp` source code +const enum GgufValueType { Uint8 = 0, Int8 = 1, Uint16 = 2, @@ -16,44 +18,42 @@ const enum MetadataValueType { Float32 = 6, Bool = 7, String = 8, - Array = 9 + Array = 9, + Uint64 = 10, + Int64 = 11, + Float64 = 12 } const ggufMagic = "GGUF"; -// these keys are ignored by default because they contain very long values that aren't very useful in the JS side of this library -const defaultIgnoreMetadataKeys = [ - "tokenizer.ggml.tokens", - "tokenizer.ggml.scores", - "tokenizer.ggml.token_type", - "tokenizer.ggml.merges" -]; - -export type GgufParsedMetadataResult = { - metadataSize: number, - metadata: GgufMetadataAny -}; export type GgufParserOptions = { fileReader: GgufFileReader, + readTensorInfo?: boolean, ignoreKeys?: string[] }; export class GgufParser { private readonly _fileReader: GgufFileReader; - public readonly ignoreKeys = defaultIgnoreMetadataKeys; + private readonly _readTensorInfo: boolean; + private readonly _ignoreKeys: string[]; - public constructor({fileReader, ignoreKeys = defaultIgnoreMetadataKeys}: GgufParserOptions) { - this.ignoreKeys = ignoreKeys; + public constructor({fileReader, readTensorInfo = true, ignoreKeys = []}: GgufParserOptions) { this._fileReader = fileReader; + this._readTensorInfo = readTensorInfo; + this._ignoreKeys = ignoreKeys; } - public async parseMetadata({logWarnings = true}: {logWarnings?: boolean} = {}): Promise { - const metadataRaw = await this._parseMetadataRaw(); + public async parseFileInfo({logWarnings = true}: {logWarnings?: boolean} = {}): Promise { + const readOffset = new GgufReadOffset(0); + const headerReadResult = await this._parseHeaderRaw(readOffset); + const tensorReadResult = this._readTensorInfo + ? await this._parseTensorInfo(headerReadResult.tensorCount, readOffset) + : null; const metadata: { [key: string]: any } = {}; - for (const [key, value] of Object.entries(metadataRaw.metadata)) { - if (this.ignoreKeys.includes(key)) + for (const [key, value] of Object.entries(headerReadResult.metadata)) { + if (this._ignoreKeys.includes(key)) continue; const {lastObject, lastKey} = GgufParser._getNestedObject(key, metadata); @@ -63,49 +63,51 @@ export class GgufParser { lastObject[lastKey] = value; } - if (typeof metadata?.general?.file_type === "number") { - metadata.general["file_type"] = parseGgufFileTypeNumber(metadata.general.file_type) || metadata.general.file_type; - } - return { - metadata: metadata as GgufMetadataAny, - metadataSize: metadataRaw.metadataSize + version: headerReadResult.version, + tensorCount: headerReadResult.tensorCount, + metadata: metadata as GgufMetadata, + tensorInfo: tensorReadResult?.tensorInfo, + metadataSize: headerReadResult.metadataSize, + tensorInfoSize: tensorReadResult?.tensorInfoSize }; } - private async _readMetadataValue(type: MetadataValueType, offset: number | GgufReadOffset): Promise { + private async _readGgufValue(type: GgufValueType, offset: number | GgufReadOffset): Promise { const readOffset = GgufReadOffset.resolveReadOffset(offset); switch (type) { - case MetadataValueType.Uint8: return await this._fileReader.readUint8(readOffset); - case MetadataValueType.Int8: return await this._fileReader.readInt8(readOffset); - case MetadataValueType.Uint16: return await this._fileReader.readUint16(readOffset); - case MetadataValueType.Int16: return await this._fileReader.readInt16(readOffset); - case MetadataValueType.Uint32: return await this._fileReader.readUint32(readOffset); - case MetadataValueType.Int32: return await this._fileReader.readInt32(readOffset); - case MetadataValueType.Float32: return await this._fileReader.readFloat32(readOffset); - case MetadataValueType.Bool: return await this._fileReader.readBool(readOffset); - case MetadataValueType.String: return await this._fileReader.readString(readOffset); + case GgufValueType.Uint8: return await this._fileReader.readUint8(readOffset); + case GgufValueType.Int8: return await this._fileReader.readInt8(readOffset); + case GgufValueType.Uint16: return await this._fileReader.readUint16(readOffset); + case GgufValueType.Int16: return await this._fileReader.readInt16(readOffset); + case GgufValueType.Uint32: return await this._fileReader.readUint32(readOffset); + case GgufValueType.Int32: return await this._fileReader.readInt32(readOffset); + case GgufValueType.Float32: return await this._fileReader.readFloat32(readOffset); + case GgufValueType.Bool: return await this._fileReader.readBool(readOffset); + case GgufValueType.String: return await this._fileReader.readString(readOffset); + case GgufValueType.Uint64: return await this._fileReader.readUint64(readOffset); + case GgufValueType.Int64: return await this._fileReader.readInt64(readOffset); + case GgufValueType.Float64: return await this._fileReader.readFloat64(readOffset); } - if (type === MetadataValueType.Array) { + if (type === GgufValueType.Array) { const arrayType = await this._fileReader.readUint32(readOffset); const arrayLength = await this._fileReader.readUint64(readOffset); const arrayValues: any[] = []; for (let i = 0; i < arrayLength; i++) { - const value = await this._readMetadataValue(arrayType, readOffset); + const value = await this._readGgufValue(arrayType, readOffset); arrayValues.push(value); } return arrayValues; } - throw new UnsupportedGgufMetadataTypeError(type); + throw new UnsupportedGgufValueTypeError(type); } - private async _parseMetadataRaw(): Promise<{metadata: Record, metadataSize: number}> { - const readOffset = new GgufReadOffset(0); - + private async _parseHeaderRaw(readOffset: GgufReadOffset) { + const initialOffset = readOffset.offset; const fileMagicBytes = await this._fileReader.readByteRange(readOffset, valueTypeToBytesToRead.uint8 * ggufMagic.length); const fileMagicText = String.fromCharCode(...fileMagicBytes); @@ -116,20 +118,50 @@ export class GgufParser { const tensorCount = await this._fileReader.readUint64(readOffset); const metadataKVCount = Number(await this._fileReader.readUint64(readOffset)); - const metadata: { [key: string]: any } = { - version, - tensorCount: GgufFileReader.castNumber(tensorCount) - }; + const metadata: Record = {}; for (let i = 0; i < metadataKVCount; i++) { const keyResult = await this._fileReader.readString(readOffset); const valueType = await this._fileReader.readUint32(readOffset); - metadata[keyResult] = await this._readMetadataValue(valueType, readOffset); + metadata[keyResult] = await this._readGgufValue(valueType, readOffset); } return { + version, + tensorCount: GgufFileReader.castNumber(tensorCount), metadata: metadata, - metadataSize: readOffset.offset + metadataSize: readOffset.offset - initialOffset + }; + } + + private async _parseTensorInfo(tensorCount: number | bigint, readOffset: GgufReadOffset) { + const initialOffset = readOffset.offset; + const tensorInfo: GgufTensorInfo[] = []; + + for (let i = 0n; i < BigInt(tensorCount); i++) { + const name = await this._fileReader.readString(readOffset); + const dimensionsNumber = await this._fileReader.readUint32(readOffset); + const dimensions: (number | bigint)[] = []; + + for (let i = 0; i < dimensionsNumber; i++) { + const dimension = await this._fileReader.readUint64(readOffset); + dimensions.push(GgufFileReader.castNumber(dimension)); + } + + const ggmlType = await this._fileReader.readUint32(readOffset); + const offset = await this._fileReader.readUint64(readOffset); + + tensorInfo.push({ + name, + dimensions, + ggmlType: ggmlType as GgmlType, + offset: GgufFileReader.castNumber(offset) + }); + } + + return { + tensorInfo, + tensorInfoSize: readOffset.offset - initialOffset }; } diff --git a/src/gguf/ggufParser/types/GgufFileInfoTypes.ts b/src/gguf/ggufParser/types/GgufFileInfoTypes.ts new file mode 100644 index 00000000..129722bd --- /dev/null +++ b/src/gguf/ggufParser/types/GgufFileInfoTypes.ts @@ -0,0 +1,15 @@ +import {GgufMetadata} from "./GgufMetadataTypes.js"; +import {GgufTensorInfo} from "./GgufTensorInfoTypes.js"; + +export type GgufFileInfo = { + version: 3 | number, + tensorCount: number | bigint, + metadata: GgufMetadata, + metadataSize: number, + + /** can be null if `readTensorInfo` is set to `false` */ + tensorInfo?: GgufTensorInfo[], + + /** can be null if `readTensorInfo` is set to `false` */ + tensorInfoSize?: number +}; diff --git a/src/gguf/ggufParser/types/GgufMetadataTypes.ts b/src/gguf/ggufParser/types/GgufMetadataTypes.ts new file mode 100644 index 00000000..f26145ca --- /dev/null +++ b/src/gguf/ggufParser/types/GgufMetadataTypes.ts @@ -0,0 +1,422 @@ +export const enum GgufArchitectureType { + llama = "llama", + falcon = "falcon", + gpt2 = "gpt2", + gptj = "gptj", + gptneox = "gptneox", + mpt = "mpt", + baichuan = "baichuan", + starcoder = "starcoder", + persimmon = "persimmon", + refact = "refact", + bert = "bert", + nomicBert = "nomic-bert", + bloom = "bloom", + stablelm = "stablelm", + qwen = "qwen", + qwen2 = "qwen2", + phi2 = "phi2", + plamo = "plamo", + codeshell = "codeshell", + orion = "orion", + internlm2 = "internlm2", + minicpm = "minicpm", + gemma = "gemma", + starcoder2 = "starcoder2", + mamba = "mamba", + commandR = "command-r", + rwkv = "rwkv" +} + +export type GgufMetadata = { + general: GgufMetadataGeneral, + tokenizer: GgufMetadataTokenizer +} & ( + GgufArchitectureType extends A ? { + [key in GgufArchitectureType]?: key extends keyof GgufMetadataLlmToType + ? GgufMetadataLlmToType[key] + : GgufMetadataLlmDefaultArchitectureType + } + : { + [key in A]: key extends keyof GgufMetadataLlmToType + ? GgufMetadataLlmToType[key] + : GgufMetadataLlmDefaultArchitectureType + } +); + + +export type GgufMetadataLlmToType = { + [GgufArchitectureType.llama]: GgufMetadataLlmLLaMA, + [GgufArchitectureType.mpt]: GgufMetadataMPT, + [GgufArchitectureType.gptneox]: GgufMetadataGPTNeoX, + [GgufArchitectureType.gptj]: GgufMetadataGPTJ, + [GgufArchitectureType.gpt2]: GgufMetadataGPT2, + [GgufArchitectureType.bloom]: GgufMetadataBloom, + [GgufArchitectureType.falcon]: GgufMetadataFalcon, + [GgufArchitectureType.mamba]: GgufMetadataMamba, + [GgufArchitectureType.rwkv]: GgufMetadataRWKV +}; + +// source: `enum llama_ftype` in `llama.h` in the `llama.cpp` source code +export enum GgufFileType { + ALL_F32 = 0, + MOSTLY_F16 = 1, + MOSTLY_Q4_0 = 2, + MOSTLY_Q4_1 = 3, + MOSTLY_Q4_1_SOME_F16 = 4, + MOSTLY_Q4_2 = 5, + MOSTLY_Q4_3 = 6, + MOSTLY_Q8_0 = 7, + MOSTLY_Q5_0 = 8, + MOSTLY_Q5_1 = 9, + MOSTLY_Q2_K = 10, + MOSTLY_Q3_K_S = 11, + MOSTLY_Q3_K_M = 12, + MOSTLY_Q3_K_L = 13, + MOSTLY_Q4_K_S = 14, + MOSTLY_Q4_K_M = 15, + MOSTLY_Q5_K_S = 16, + MOSTLY_Q5_K_M = 17, + MOSTLY_Q6_K = 18, + MOSTLY_IQ2_XXS = 19, + MOSTLY_IQ2_XS = 20, + MOSTLY_Q2_K_S = 21, + MOSTLY_IQ3_XS = 22, + MOSTLY_IQ3_XXS = 23, + MOSTLY_IQ1_S = 24, + MOSTLY_IQ4_NL = 25, + MOSTLY_IQ3_S = 26, + MOSTLY_IQ3_M = 27, + MOSTLY_IQ2_S = 28, + MOSTLY_IQ2_M = 29, + MOSTLY_IQ4_XS = 30 +} + + +export type GgufMetadataGeneral = { + architecture: A, + + /** + * The version of the quantization format. Not required if the model is not + * quantized (i.e. no tensors are quantized). If any tensors are quantized, + * this must be present. This is separate to the quantization scheme of the + * tensors itself; the quantization version may change without changing the + * scheme's name (e.g. the quantization scheme is Q5_K, and the quantization + * version is 4). + */ + quantization_version: string, + + /** + * the global alignment to use, as described above. This can vary to allow + * for different alignment schemes, but it must be a multiple of 8. Some + * writers may not write the alignment. If the alignment is not specified, + * assume it is `32`. + */ + alignment?: string, + + /** + * The name of the model. This should be a human-readable name that can be + * used to identify the model. It should be unique within the community + * that the model is defined in. + */ + name?: string, + author?: string, + + /** + * URL to the model's homepage. This can be a GitHub repo, a paper, etc. + */ + url?: string, + + /** + * free-form description of the model including anything that isn't + * covered by the other fields + */ + description?: string, + + /** + * License of the model, expressed as a SPDX license expression + * (e.g. `MIT OR Apache-2.0`). *Should not* include any other information, + * such as the license text or the URL to the license. + */ + license?: string, + + /** + * Information about where this model came from. This is useful for tracking + * the provenance of the model, and for finding the original source if the + * model is modified. For a model that was converted from GGML, for + * example, these keys would point to the model that was converted from. + */ + source?: { + /** + * URL to the source of the model. Can be a GitHub repo, a paper, etc. + */ + url?: string, + huggingface?: { + repository?: string + } + }, + + /** + * An enumerated value describing the type of the majority of the tensors + * in the file. Optional; can be inferred from the tensor types. + */ + file_type?: GgufFileType | undefined +}; + +export const enum GgufMetadataTokenizerTokenType { + undefined = 0, + normal = 1, + unknown = 2, + control = 3, + userDefined = 4, + unused = 5, + byte = 6 +} + +export type GgufMetadataTokenizer = { + ggml: { + model: "no_vocab" | "llama" | "gpt2" | "bert" | "replit" | "rwkv" | string, + tokens: string[], + token_type: GgufMetadataTokenizerTokenType[], + token_type_count?: number, + scores?: number[], + merges?: string[], + bos_token_id?: number, + eos_token_id?: number, + unknown_token_id?: number, + separator_token_id?: number, + padding_token_id?: number, + add_bos_token?: boolean, + add_eos_token?: boolean, + add_space_prefix?: boolean, + added_tokens?: string[] + }, + huggingface?: { + json?: string + }, + chat_template?: string +}; + +export const enum GgufMetadataLlmPoolingType { + unspecified = -1, + none = 0, + mean = 1, + max = 2, +} + +export type GgufMetadataLlmDefaultArchitectureType = { + vocab_size?: number, + context_length?: number, + embedding_length?: number, + block_count?: number, + feed_forward_length?: number, + use_parallel_residual?: boolean, + tensor_data_layout?: string, + expert_count?: number, + expert_used_count?: number, + pooling_type?: GgufMetadataLlmPoolingType, + logit_scale?: number, + + attention?: { + head_count?: number, + head_count_kv?: number, + max_alibi_bias?: number, + clamp_kqv?: number, + layer_norm_epsilon?: number, + layer_norm_rms_epsilon?: number, + key_length?: number, + value_length?: number, + causal?: boolean + }, + + rope?: { + dimension_count?: number, + freq_base?: number, + scale_linear?: number, + scaling?: { + type?: "none" | "linear" | "yarn" | string, + factor?: number, + original_context_length?: number, + finetuned?: boolean + } + }, + + ssm?: { + conv_kernel?: number, + inner_size?: number, + state_size?: number, + time_step_rank?: number + } +}; + +// export type GgufMetadataLlmKeyTypes = { +// context_length: number, +// embedding_length: number, +// block_count: number, +// feed_forward_length: number, +// use_parallel_residual: boolean, +// tensor_data_layout: string, +// expert_count: number, +// expert_used_count: number, +// +// attention: { +// head_count: number, +// head_count_kv: number, +// max_alibi_bias: number, +// clamp_kqv: number, +// layer_norm_epsilon: number, +// layer_norm_rms_epsilon: number, +// key_length: number, +// value_length: number +// }, +// +// rope: { +// dimension_count: number, +// freq_base: number, +// scaling: { +// type: "none" | "linear" | "yarn" | string, +// factor: number, +// original_context_length: number, +// finetuned: boolean, +// scale_linear?: number +// } +// }, +// +// ssm: { +// conv_kernel: number, +// inner_size: number, +// state_size: number, +// time_step_rank: number +// } +// }; + +// source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#llama +export type GgufMetadataLlmLLaMA = { + context_length: number, + embedding_length: number, + block_count: number, + feed_forward_length: number, + attention: { + head_count: number, + layer_norm_rms_epsilon: number, + head_count_kv?: number + }, + rope: { + dimension_count: number, + scale?: number + }, + expert_count?: number, + expert_used_count?: number, + tensor_data_layout?: string +}; + +// source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#mpt +export type GgufMetadataMPT = { + context_length: number, + embedding_length: number, + block_count: number, + attention: { + head_count: number, + alibi_bias_max: number, + clip_kqv: number, + layer_norm_epsilon: number + } +}; + +// source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gpt-neox +export type GgufMetadataGPTNeoX = { + context_length: number, + embedding_length: number, + block_count: number, + use_parallel_residual: boolean, + rope: { + dimension_count: number, + // freq_base: number, + scale?: number + }, + attention: { + head_count: number, + layer_norm_epsilon: number + } +}; + +// source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gpt-j +export type GgufMetadataGPTJ = { + context_length: number, + embedding_length: number, + block_count: number, + rope: { + dimension_count: number, + scale?: number + }, + attention: { + head_count: number, + layer_norm_epsilon: number + } +}; + +// source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#gpt-2 +export type GgufMetadataGPT2 = { + context_length: number, + embedding_length: number, + block_count: number, + attention: { + head_count: number, + layer_norm_epsilon: number + } +}; + +// source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#bloom +export type GgufMetadataBloom = { + context_length: number, + embedding_length: number, + block_count: number, + feed_forward_length: number, + attention: { + head_count: number, + layer_norm_epsilon: number + } +}; + +// source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#falcon +export type GgufMetadataFalcon = { + context_length: number, + embedding_length: number, + block_count: number, + attention: { + head_count: number, + head_count_kv: number, + use_norm: boolean, + layer_norm_epsilon: number + }, + tensor_data_layout?: string +}; + +// source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#mamba +export type GgufMetadataMamba = { + context_length: number, + embedding_length: number, + block_count: number, + ssm: { + conv_kernel: number, + inner_size: number, + state_size: number, + time_step_rank: number + }, + attention: { + layer_norm_rms_epsilon: number + } +}; + +// source: https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#rwkv +export type GgufMetadataRWKV = { + architecture_version: 4 | number, + context_length: number, + block_count: number, + embedding_length: number, + feed_forward_length: number +}; + +export function isGgufMetadataOfArchitectureType(metadata: GgufMetadata, type: A): metadata is GgufMetadata { + return metadata?.general?.architecture === type; +} diff --git a/src/gguf/ggufParser/types/GgufTensorInfoTypes.ts b/src/gguf/ggufParser/types/GgufTensorInfoTypes.ts new file mode 100644 index 00000000..082d0260 --- /dev/null +++ b/src/gguf/ggufParser/types/GgufTensorInfoTypes.ts @@ -0,0 +1,38 @@ +export type GgufTensorInfo = { + name: string, + dimensions: (number | bigint)[], + ggmlType: GgmlType, + offset: number | bigint +}; + +export const enum GgmlType { + F32 = 0, + F16 = 1, + Q4_0 = 2, + Q4_1 = 3, + Q4_2 = 4, + Q4_3 = 5, + Q5_0 = 6, + Q5_1 = 7, + Q8_0 = 8, + Q8_1 = 9, + Q2_K = 10, + Q3_K = 11, + Q4_K = 12, + Q5_K = 13, + Q6_K = 14, + Q8_K = 15, + IQ2_XXS = 16, + IQ2_XS = 17, + IQ3_XXS = 18, + IQ1_S = 19, + IQ4_NL = 20, + IQ3_S = 21, + IQ2_S = 22, + IQ4_XS = 23, + I8 = 24, + I16 = 25, + I32 = 26, + I64 = 27, + F64 = 28 +} diff --git a/src/gguf/ggufParser/utils/getGgufFileTypeName.ts b/src/gguf/ggufParser/utils/getGgufFileTypeName.ts new file mode 100644 index 00000000..2f8c134b --- /dev/null +++ b/src/gguf/ggufParser/utils/getGgufFileTypeName.ts @@ -0,0 +1,14 @@ +import {GgufFileType} from "../types/GgufMetadataTypes.js"; + +const fileTypeNumberToNameMap = new Map(); +for (const [key, value] of Object.entries(GgufFileType)) { + if (typeof value === "number") + fileTypeNumberToNameMap.set(value, key as keyof typeof GgufFileType); +} + +/** + * Convert a GGUF file type number to its corresponding type name + */ +export function getGgufFileTypeName(fileType?: number) { + return fileTypeNumberToNameMap.get(fileType!) ?? undefined; +} diff --git a/src/gguf/ggufParser/utils/getGgufMetadataLlmData.ts b/src/gguf/ggufParser/utils/getGgufMetadataLlmData.ts new file mode 100644 index 00000000..118cb104 --- /dev/null +++ b/src/gguf/ggufParser/utils/getGgufMetadataLlmData.ts @@ -0,0 +1,10 @@ +import {GgufArchitectureType, GgufMetadata} from "../types/GgufMetadataTypes.js"; +import {MergeOptionalUnionTypes} from "../../../utils/mergeUnionTypes.js"; + +export function getGgufMetadataLlmData(ggufMetadata: GgufMetadata): ( + GgufArchitectureType extends T + ? MergeOptionalUnionTypes> + : GgufMetadata[T] +) { + return ggufMetadata[ggufMetadata.general.architecture] ?? {} as any; +} diff --git a/src/gguf/ggufParser/utils/parseGgufFileTypeNumber.ts b/src/gguf/ggufParser/utils/parseGgufFileTypeNumber.ts deleted file mode 100644 index 1baa7ef6..00000000 --- a/src/gguf/ggufParser/utils/parseGgufFileTypeNumber.ts +++ /dev/null @@ -1,34 +0,0 @@ -import {GgufFileType} from "../GgufMetadataTypes.js"; - -/** - * https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#general-metadata - * Convert file type from string to int - */ -export function parseGgufFileTypeNumber(fileType?: number) { - if (fileType == null) - return undefined; - - switch (fileType) { - case 0: return GgufFileType.ALL_F32; - case 1: return GgufFileType.MOSTLY_F16; - case 2: return GgufFileType.MOSTLY_Q4_0; - case 3: return GgufFileType.MOSTLY_Q4_1; - case 4: return GgufFileType.MOSTLY_Q4_1_SOME_F16; - case 5: return GgufFileType.MOSTLY_Q4_2; - case 6: return GgufFileType.MOSTLY_Q4_3; - case 7: return GgufFileType.MOSTLY_Q8_0; - case 8: return GgufFileType.MOSTLY_Q5_0; - case 9: return GgufFileType.MOSTLY_Q5_1; - case 10: return GgufFileType.MOSTLY_Q2_K; - case 11: return GgufFileType.MOSTLY_Q3_K_S; - case 12: return GgufFileType.MOSTLY_Q3_K_M; - case 13: return GgufFileType.MOSTLY_Q3_K_L; - case 14: return GgufFileType.MOSTLY_Q4_K_S; - case 15: return GgufFileType.MOSTLY_Q4_K_M; - case 16: return GgufFileType.MOSTLY_Q5_K_S; - case 17: return GgufFileType.MOSTLY_Q5_K_M; - case 18: return GgufFileType.MOSTLY_Q6_K; - } - - return undefined; -} diff --git a/src/utils/mergeUnionTypes.ts b/src/utils/mergeUnionTypes.ts index f6cc13ec..b1094eb1 100644 --- a/src/utils/mergeUnionTypes.ts +++ b/src/utils/mergeUnionTypes.ts @@ -9,5 +9,9 @@ type UnionToIntersection = ( type DistributeUnion = { [K in keyof U]: U[K] }; +type OptionalDistributeUnion = { + [K in keyof U]?: U[K] +}; export type MergeUnionTypes = DistributeUnion>; +export type MergeOptionalUnionTypes = OptionalDistributeUnion>; diff --git a/src/utils/prettyPrintObject.ts b/src/utils/prettyPrintObject.ts index 69947b2e..d02b3fca 100644 --- a/src/utils/prettyPrintObject.ts +++ b/src/utils/prettyPrintObject.ts @@ -1,22 +1,20 @@ import chalk from "chalk"; +import stripAnsi from "strip-ansi"; -type PrettyPrintObjectOptions = { +export type PrettyPrintObjectOptions = { maxArrayValues?: number, - useNumberGrouping?: boolean + useNumberGrouping?: boolean, + maxArrayItemsWidth?: number, + + // `true` by default + multilineObjects?: boolean }; export function prettyPrintObject(obj: any, indent: number = 4, options: PrettyPrintObjectOptions = {}): string { if (typeof obj === "string") return chalk.green(JSON.stringify(obj, null, 4)); - else if (typeof obj === "number") - return chalk.yellow( - options.useNumberGrouping - ? obj.toLocaleString("en-US", { - style: "decimal", - useGrouping: true - }).replaceAll(",", "_") - : obj - ); + else if (typeof obj === "number" || typeof obj === "bigint") + return chalk.yellow(formatNumber(obj, {useNumberGrouping: options.useNumberGrouping})); else if (typeof obj === "boolean") return chalk.magenta.italic(obj); else if (obj === null) @@ -26,25 +24,31 @@ export function prettyPrintObject(obj: any, indent: number = 4, options: PrettyP else if (obj instanceof Array) return prettyPrintArray(obj, indent, options); + const nl = options.multilineObjects ?? true; const rows: string[] = []; for (const key of Object.keys(obj)) { const value = obj[key as keyof typeof obj]; rows.push([ - " ".repeat(indent), + (nl ? " ".repeat(indent) : ""), canStringBeKeyWithoutQuotes(key) ? chalk.red(key) : chalk.green(JSON.stringify(key)), chalk.whiteBright(": "), prettyPrintObject(value, indent, options) - .replaceAll("\n", "\n" + " ".repeat(indent)) + .replaceAll("\n", "\n" + (nl ? " ".repeat(indent) : "")) ].join("")); } if (rows.length === 0) return chalk.whiteBright("{}"); - return chalk.whiteBright("{\n") + rows.join(chalk.whiteBright(",\n")) + chalk.whiteBright("\n") + chalk.whiteBright("}"); + return [ + chalk.whiteBright("{" + (nl ? "\n" : "")), + rows.join(chalk.whiteBright("," + (nl ? "\n" : " "))), + (nl ? "\n" : ""), + chalk.whiteBright("}") + ].join(""); } function canStringBeKeyWithoutQuotes(key: string): boolean { @@ -57,15 +61,48 @@ function prettyPrintArray(arr: any[], indent: number = 4, options: PrettyPrintOb : arr; const hiddenItems = arr.length - slicedArray.length; + const arrayItems = slicedArray.map((item) => prettyPrintObject(item, indent, options)) + .concat( + hiddenItems > 0 + ? [chalk.white("..." + hiddenItems + " more item" + (hiddenItems !== 1 ? "s" : ""))] + : [] + ); + const oneLineJoinedArrayItems = arrayItems.join(chalk.whiteBright(", ")); + + if (options.maxArrayItemsWidth != null && + ("[".length + stripAnsi(oneLineJoinedArrayItems).length + "]".length) > options.maxArrayItemsWidth + ) { + return [ + chalk.whiteBright("["), + "\n", + " ".repeat(indent), + arrayItems + .join(chalk.whiteBright(",") + "\n") + .replaceAll("\n", "\n" + " ".repeat(indent)), + "\n", + chalk.whiteBright("]") + ].join(""); + } + return [ chalk.whiteBright("["), - slicedArray.map((item) => prettyPrintObject(item, indent, options)) - .concat( - hiddenItems > 0 - ? [chalk.white("..." + hiddenItems + " more item" + (hiddenItems !== 1 ? "s" : ""))] - : [] - ) - .join(chalk.whiteBright(", ")), + oneLineJoinedArrayItems, chalk.whiteBright("]") ].join(""); } + +export function formatNumber(num: number | bigint, {useNumberGrouping = false}: {useNumberGrouping?: boolean} = {}): string { + let res = useNumberGrouping + ? num + .toLocaleString("en-US", { + style: "decimal", + useGrouping: true + }) + .replaceAll(",", "_") + : String(num); + + if (typeof num === "bigint") + res += "n"; + + return res; +} diff --git a/test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap b/test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap index 21411c0c..ffd58170 100644 --- a/test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap +++ b/test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap @@ -5,7 +5,7 @@ exports[`GGUF Parser > should fetch GGUF metadata 1`] = ` "metadata": { "general": { "architecture": "llama", - "file_type": "MOSTLY_Q4_0", + "file_type": 2, "name": "workspace", "quantization_version": 2, }, @@ -24,7 +24,6 @@ exports[`GGUF Parser > should fetch GGUF metadata 1`] = ` "freq_base": 10000, }, }, - "tensorCount": 291, "tokenizer": { "chat_template": "{% for message in messages %} {% if message['role'] == 'user' or message['role'] == 'system' %} @@ -63,12 +62,87 @@ exports[`GGUF Parser > should fetch GGUF metadata 1`] = ` "eos_token_id": 2, "model": "llama", "padding_token_id": 2, + "scores": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + "token_type": [ + 2, + 3, + 3, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + ], + "tokens": [ + "", + "", + "", + "<0x00>", + "<0x01>", + "<0x02>", + "<0x03>", + "<0x04>", + "<0x05>", + "<0x06>", + ], "unknown_token_id": 0, }, }, - "version": 3, }, "metadataSize": 718762, + "tensorCount": 291, + "tensorInfo": [ + { + "dimensions": [ + 4096, + 32004, + ], + "ggmlType": 2, + "name": "token_embd.weight", + "offset": 0, + }, + { + "dimensions": [ + 4096, + ], + "ggmlType": 0, + "name": "blk.0.attn_norm.weight", + "offset": 73737216, + }, + { + "dimensions": [ + 14336, + 4096, + ], + "ggmlType": 2, + "name": "blk.0.ffn_down.weight", + "offset": 73753600, + }, + { + "dimensions": [ + 4096, + 14336, + ], + "ggmlType": 2, + "name": "blk.0.ffn_gate.weight", + "offset": 106783744, + }, + ], + "tensorInfoSize": 17286, + "version": 3, } `; @@ -77,7 +151,7 @@ exports[`GGUF Parser > should parse local gguf model 1`] = ` "metadata": { "general": { "architecture": "llama", - "file_type": "MOSTLY_Q4_0", + "file_type": 2, "name": "workspace", "quantization_version": 2, }, @@ -96,7 +170,6 @@ exports[`GGUF Parser > should parse local gguf model 1`] = ` "freq_base": 10000, }, }, - "tensorCount": 291, "tokenizer": { "chat_template": "{% for message in messages %} {% if message['role'] == 'user' or message['role'] == 'system' %} @@ -135,11 +208,86 @@ exports[`GGUF Parser > should parse local gguf model 1`] = ` "eos_token_id": 2, "model": "llama", "padding_token_id": 2, + "scores": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + "token_type": [ + 2, + 3, + 3, + 6, + 6, + 6, + 6, + 6, + 6, + 6, + ], + "tokens": [ + "", + "", + "", + "<0x00>", + "<0x01>", + "<0x02>", + "<0x03>", + "<0x04>", + "<0x05>", + "<0x06>", + ], "unknown_token_id": 0, }, }, - "version": 3, }, "metadataSize": 718762, + "tensorCount": 291, + "tensorInfo": [ + { + "dimensions": [ + 4096, + 32004, + ], + "ggmlType": 2, + "name": "token_embd.weight", + "offset": 0, + }, + { + "dimensions": [ + 4096, + ], + "ggmlType": 0, + "name": "blk.0.attn_norm.weight", + "offset": 73737216, + }, + { + "dimensions": [ + 14336, + 4096, + ], + "ggmlType": 2, + "name": "blk.0.ffn_down.weight", + "offset": 73753600, + }, + { + "dimensions": [ + 4096, + 14336, + ], + "ggmlType": 2, + "name": "blk.0.ffn_gate.weight", + "offset": 106783744, + }, + ], + "tensorInfoSize": 17286, + "version": 3, } `; diff --git a/test/modelDependent/functionary/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts index ab886a89..68628895 100644 --- a/test/modelDependent/functionary/gguf.test.ts +++ b/test/modelDependent/functionary/gguf.test.ts @@ -4,7 +4,8 @@ import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; import {getModelFile} from "../../utils/modelFiles.js"; import {GgufInsights} from "../../../src/gguf/GgufInsights.js"; import {getTestLlama} from "../../utils/getTestLlama.js"; -import {parseGgufMetadata} from "../../../src/gguf/parseGgufMetadata.js"; +import {getGgufFileInfo} from "../../../src/gguf/getGgufFileInfo.js"; +import {simplifyGgufInfoForTestSnapshot} from "../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; describe("GGUF Parser", async () => { const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); @@ -19,18 +20,21 @@ describe("GGUF Parser", async () => { it("should parse local gguf model", async () => { const fileReader = new GgufFsFileReader({filePath: modelPath}); - const ggufParser = new GgufParser({fileReader: fileReader}); + const ggufParser = new GgufParser({ + fileReader: fileReader + }); - const metadata = await ggufParser.parseMetadata(); + const metadata = await ggufParser.parseFileInfo(); + expect(metadata.tensorInfo!.length).to.be.eql(Number(metadata.tensorCount)); - expect(metadata).toMatchSnapshot(); + expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); }); it("should calculate GGUF VRAM Usage", async () => { const fileReader = new GgufFsFileReader({filePath: modelPath}); const ggufParser = new GgufParser({fileReader: fileReader}); - const metadata = await ggufParser.parseMetadata(); + const metadata = await ggufParser.parseFileInfo(); const ggufInsights = new GgufInsights(metadata); @@ -48,9 +52,9 @@ describe("GGUF Parser", async () => { }); it("should fetch GGUF metadata", async () => { - const ggufMetadataParseResult = await parseGgufMetadata(modelPath); + const ggufMetadataParseResult = await getGgufFileInfo(modelPath); - expect(ggufMetadataParseResult).toMatchSnapshot(); + expect(simplifyGgufInfoForTestSnapshot(ggufMetadataParseResult)).toMatchSnapshot(); const insights = new GgufInsights(ggufMetadataParseResult); expect(insights.VRAMUsage).toMatchInlineSnapshot("4474643028.666667"); diff --git a/test/standalone/gguf/__snapshots__/gguf.test.ts.snap b/test/standalone/gguf/__snapshots__/gguf.test.ts.snap index f91e892d..fb6b6e60 100644 --- a/test/standalone/gguf/__snapshots__/gguf.test.ts.snap +++ b/test/standalone/gguf/__snapshots__/gguf.test.ts.snap @@ -17,19 +17,190 @@ exports[`GGUF Parser > should parse remote gguf model 1`] = ` }, "general": { "architecture": "falcon", - "file_type": "MOSTLY_Q6_K", + "file_type": 18, "name": "Falcon", "quantization_version": 2, }, - "tensorCount": 644, "tokenizer": { "ggml": { "eos_token_id": 11, + "merges": [ + "Ä  t", + "Ä  a", + "i n", + "h e", + "r e", + "o n", + "e r", + "Ä  s", + "Ä t he", + "a t", + ], "model": "gpt2", + "scores": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + "token_type": [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + "tokens": [ + ">>TITLE<<", + ">>ABSTRACT<<", + ">>INTRODUCTION<<", + ">>SUMMARY<<", + ">>COMMENT<<", + ">>ANSWER<<", + ">>QUESTION<<", + ">>DOMAIN<<", + ">>PREFIX<<", + ">>SUFFIX<<", + ], }, }, - "version": 2, }, "metadataSize": 2547826, + "tensorCount": 644, + "tensorInfo": [ + { + "dimensions": [ + 14848, + 59392, + ], + "ggmlType": 14, + "name": "blk.0.ffn_up.weight", + "offset": 0, + }, + { + "dimensions": [ + 14848, + 14848, + ], + "ggmlType": 14, + "name": "blk.0.attn_output.weight", + "offset": 723394560, + }, + { + "dimensions": [ + 14848, + 15872, + ], + "ggmlType": 14, + "name": "blk.0.attn_qkv.weight", + "offset": 904243200, + }, + { + "dimensions": [ + 14848, + 65024, + ], + "ggmlType": 14, + "name": "token_embd.weight", + "offset": 1097564160, + }, + ], + "tensorInfoSize": 37648, + "version": 2, +} +`; + +exports[`GGUF Parser > should parse remote gguf model without tensor info 1`] = ` +{ + "metadata": { + "falcon": { + "attention": { + "head_count": 232, + "head_count_kv": 8, + "layer_norm_epsilon": 0.000009999999747378752, + }, + "block_count": 80, + "context_length": 2048, + "embedding_length": 14848, + "feed_forward_length": 59392, + "tensor_data_layout": "jploski", + }, + "general": { + "architecture": "falcon", + "file_type": 18, + "name": "Falcon", + "quantization_version": 2, + }, + "tokenizer": { + "ggml": { + "eos_token_id": 11, + "merges": [ + "Ä  t", + "Ä  a", + "i n", + "h e", + "r e", + "o n", + "e r", + "Ä  s", + "Ä t he", + "a t", + ], + "model": "gpt2", + "scores": [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ], + "token_type": [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + 1, + ], + "tokens": [ + ">>TITLE<<", + ">>ABSTRACT<<", + ">>INTRODUCTION<<", + ">>SUMMARY<<", + ">>COMMENT<<", + ">>ANSWER<<", + ">>QUESTION<<", + ">>DOMAIN<<", + ">>PREFIX<<", + ">>SUFFIX<<", + ], + }, + }, + }, + "metadataSize": 2547826, + "tensorCount": 644, + "tensorInfo": undefined, + "tensorInfoSize": undefined, + "version": 2, } `; diff --git a/test/standalone/gguf/gguf.test.ts b/test/standalone/gguf/gguf.test.ts index 95887331..724cb0cd 100644 --- a/test/standalone/gguf/gguf.test.ts +++ b/test/standalone/gguf/gguf.test.ts @@ -1,6 +1,7 @@ import {describe, expect, it, test} from "vitest"; import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; import {GgufNetworkFetchFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufNetworkFetchFileReader.js"; +import {simplifyGgufInfoForTestSnapshot} from "../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; @@ -16,10 +17,24 @@ describe("GGUF Parser", async () => { it("should parse remote gguf model", async () => { const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); - const ggufParser = new GgufParser({fileReader: fileReader}); + const ggufParser = new GgufParser({ + fileReader: fileReader + }); - const metadata = await ggufParser.parseMetadata(); + const metadata = await ggufParser.parseFileInfo(); - expect(metadata).toMatchSnapshot(); + expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); + }); + + it("should parse remote gguf model without tensor info", async () => { + const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); + const ggufParser = new GgufParser({ + fileReader: fileReader, + readTensorInfo: false + }); + + const metadata = await ggufParser.parseFileInfo(); + + expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); }); }); diff --git a/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts b/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts new file mode 100644 index 00000000..f09299f7 --- /dev/null +++ b/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts @@ -0,0 +1,23 @@ +import {GgufFileInfo} from "../../../src/gguf/ggufParser/types/GgufFileInfoTypes.js"; + +export function simplifyGgufInfoForTestSnapshot(ggufFileInfo: GgufFileInfo) { + const ggufFileInfoCopy = structuredClone(ggufFileInfo); + + // these keys are ignored in tests because they contain very long values, so we don't want to include them in full, + // to make sure we won't make the test snapshots huge, keeping them readable and maintainable + shortenArray(ggufFileInfoCopy.metadata.tokenizer.ggml.tokens, 10); + shortenArray(ggufFileInfoCopy.metadata.tokenizer.ggml.scores, 10); + shortenArray(ggufFileInfoCopy.metadata.tokenizer.ggml.token_type, 10); + shortenArray(ggufFileInfoCopy.metadata.tokenizer.ggml.merges, 10); + + shortenArray(ggufFileInfoCopy.tensorInfo, 4); + + return ggufFileInfoCopy; +} + +function shortenArray(array?: any[], maxSize: number = 10) { + if (array == null) + return; + + array.splice(maxSize); +} From 29c629c201ef2a7eabb41913dad7560a192efe30 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 20 Mar 2024 02:57:32 +0200 Subject: [PATCH 14/52] style: lint fix --- src/gguf/ggufParser/types/GgufMetadataTypes.ts | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/gguf/ggufParser/types/GgufMetadataTypes.ts b/src/gguf/ggufParser/types/GgufMetadataTypes.ts index f26145ca..352a684f 100644 --- a/src/gguf/ggufParser/types/GgufMetadataTypes.ts +++ b/src/gguf/ggufParser/types/GgufMetadataTypes.ts @@ -417,6 +417,8 @@ export type GgufMetadataRWKV = { feed_forward_length: number }; -export function isGgufMetadataOfArchitectureType(metadata: GgufMetadata, type: A): metadata is GgufMetadata { +export function isGgufMetadataOfArchitectureType( + metadata: GgufMetadata, type: A +): metadata is GgufMetadata { return metadata?.general?.architecture === type; } From 24093b1aa6f3b422ef571c542b1014bfa0fb64a8 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 20 Mar 2024 18:07:42 +0200 Subject: [PATCH 15/52] refactor: use a gguf version specific parser --- src/gguf/getGgufFileInfo.ts | 15 +- src/gguf/ggufParser/GgufParser.ts | 192 ------------------ .../ggufParser/fileReaders/GgufFileReader.ts | 16 +- src/gguf/ggufParser/parseGguf.ts | 79 +++++++ src/gguf/ggufParser/parsers/GgufV2Parser.ts | 142 +++++++++++++ .../ggufParser/types/GgufFileInfoTypes.ts | 46 ++++- ...ertMetadataKeyValueRecordToNestedObject.ts | 59 ++++++ test/modelDependent/functionary/gguf.test.ts | 12 +- test/standalone/gguf/gguf.test.ts | 12 +- 9 files changed, 348 insertions(+), 225 deletions(-) delete mode 100644 src/gguf/ggufParser/GgufParser.ts create mode 100644 src/gguf/ggufParser/parseGguf.ts create mode 100644 src/gguf/ggufParser/parsers/GgufV2Parser.ts create mode 100644 src/gguf/ggufParser/utils/convertMetadataKeyValueRecordToNestedObject.ts diff --git a/src/gguf/getGgufFileInfo.ts b/src/gguf/getGgufFileInfo.ts index 137fcb18..be964666 100644 --- a/src/gguf/getGgufFileInfo.ts +++ b/src/gguf/getGgufFileInfo.ts @@ -1,5 +1,5 @@ import retry from "async-retry"; -import {GgufParser} from "./ggufParser/GgufParser.js"; +import {parseGguf} from "./ggufParser/parseGguf.js"; import {GgufNetworkFetchFileReader} from "./ggufParser/fileReaders/GgufNetworkFetchFileReader.js"; import {GgufFsFileReader} from "./ggufParser/fileReaders/GgufFsFileReader.js"; import {ggufDefaultRetryOptions} from "./consts.js"; @@ -12,12 +12,14 @@ export async function getGgufFileInfo(pathOrUrl: string, { readTensorInfo = true, sourceType, retryOptions = ggufDefaultRetryOptions, - ignoreKeys = [] + ignoreKeys = [], + logWarnings = true }: { readTensorInfo?: boolean, sourceType?: "network" | "filesystem", retryOptions?: retry.Options, - ignoreKeys?: string[] + ignoreKeys?: string[], + logWarnings?: boolean } = {}) { function createFileReader() { if (sourceType === "network" || (sourceType == null && (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")))) { @@ -37,11 +39,10 @@ export async function getGgufFileInfo(pathOrUrl: string, { } const fileReader = createFileReader(); - const parser = new GgufParser({ + return await parseGguf({ fileReader, ignoreKeys, - readTensorInfo + readTensorInfo, + logWarnings }); - - return await parser.parseFileInfo(); } diff --git a/src/gguf/ggufParser/GgufParser.ts b/src/gguf/ggufParser/GgufParser.ts deleted file mode 100644 index 091f4fab..00000000 --- a/src/gguf/ggufParser/GgufParser.ts +++ /dev/null @@ -1,192 +0,0 @@ -import {InvalidGgufMagicError} from "../errors/InvalidGgufMagicError.js"; -import {UnsupportedGgufValueTypeError} from "../errors/UnsupportedGgufValueTypeError.js"; -import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; -import {GgufReadOffset} from "./utils/GgufReadOffset.js"; -import {GgufMetadata} from "./types/GgufMetadataTypes.js"; -import {GgufFileReader, valueTypeToBytesToRead} from "./fileReaders/GgufFileReader.js"; -import {GgufFileInfo} from "./types/GgufFileInfoTypes.js"; -import {GgmlType, GgufTensorInfo} from "./types/GgufTensorInfoTypes.js"; - -// source: `enum gguf_type` in `ggml.h` in the `llama.cpp` source code -const enum GgufValueType { - Uint8 = 0, - Int8 = 1, - Uint16 = 2, - Int16 = 3, - Uint32 = 4, - Int32 = 5, - Float32 = 6, - Bool = 7, - String = 8, - Array = 9, - Uint64 = 10, - Int64 = 11, - Float64 = 12 -} - -const ggufMagic = "GGUF"; - - -export type GgufParserOptions = { - fileReader: GgufFileReader, - readTensorInfo?: boolean, - ignoreKeys?: string[] -}; - -export class GgufParser { - private readonly _fileReader: GgufFileReader; - private readonly _readTensorInfo: boolean; - private readonly _ignoreKeys: string[]; - - public constructor({fileReader, readTensorInfo = true, ignoreKeys = []}: GgufParserOptions) { - this._fileReader = fileReader; - this._readTensorInfo = readTensorInfo; - this._ignoreKeys = ignoreKeys; - } - - public async parseFileInfo({logWarnings = true}: {logWarnings?: boolean} = {}): Promise { - const readOffset = new GgufReadOffset(0); - const headerReadResult = await this._parseHeaderRaw(readOffset); - const tensorReadResult = this._readTensorInfo - ? await this._parseTensorInfo(headerReadResult.tensorCount, readOffset) - : null; - const metadata: { [key: string]: any } = {}; - - for (const [key, value] of Object.entries(headerReadResult.metadata)) { - if (this._ignoreKeys.includes(key)) - continue; - - const {lastObject, lastKey} = GgufParser._getNestedObject(key, metadata); - if (Object.hasOwn(lastObject, lastKey) && logWarnings) - console.warn(getConsoleLogPrefix() + `Metadata key "${key}" is already occupied by a value. Overwriting it.`); - - lastObject[lastKey] = value; - } - - return { - version: headerReadResult.version, - tensorCount: headerReadResult.tensorCount, - metadata: metadata as GgufMetadata, - tensorInfo: tensorReadResult?.tensorInfo, - metadataSize: headerReadResult.metadataSize, - tensorInfoSize: tensorReadResult?.tensorInfoSize - }; - } - - private async _readGgufValue(type: GgufValueType, offset: number | GgufReadOffset): Promise { - const readOffset = GgufReadOffset.resolveReadOffset(offset); - - switch (type) { - case GgufValueType.Uint8: return await this._fileReader.readUint8(readOffset); - case GgufValueType.Int8: return await this._fileReader.readInt8(readOffset); - case GgufValueType.Uint16: return await this._fileReader.readUint16(readOffset); - case GgufValueType.Int16: return await this._fileReader.readInt16(readOffset); - case GgufValueType.Uint32: return await this._fileReader.readUint32(readOffset); - case GgufValueType.Int32: return await this._fileReader.readInt32(readOffset); - case GgufValueType.Float32: return await this._fileReader.readFloat32(readOffset); - case GgufValueType.Bool: return await this._fileReader.readBool(readOffset); - case GgufValueType.String: return await this._fileReader.readString(readOffset); - case GgufValueType.Uint64: return await this._fileReader.readUint64(readOffset); - case GgufValueType.Int64: return await this._fileReader.readInt64(readOffset); - case GgufValueType.Float64: return await this._fileReader.readFloat64(readOffset); - } - - if (type === GgufValueType.Array) { - const arrayType = await this._fileReader.readUint32(readOffset); - const arrayLength = await this._fileReader.readUint64(readOffset); - - const arrayValues: any[] = []; - for (let i = 0; i < arrayLength; i++) { - const value = await this._readGgufValue(arrayType, readOffset); - arrayValues.push(value); - } - return arrayValues; - } - - throw new UnsupportedGgufValueTypeError(type); - } - - private async _parseHeaderRaw(readOffset: GgufReadOffset) { - const initialOffset = readOffset.offset; - const fileMagicBytes = await this._fileReader.readByteRange(readOffset, valueTypeToBytesToRead.uint8 * ggufMagic.length); - const fileMagicText = String.fromCharCode(...fileMagicBytes); - - if (fileMagicText !== ggufMagic) - throw new InvalidGgufMagicError(ggufMagic, fileMagicText); - - const version = await this._fileReader.readUint32(readOffset); - const tensorCount = await this._fileReader.readUint64(readOffset); - const metadataKVCount = Number(await this._fileReader.readUint64(readOffset)); - - const metadata: Record = {}; - - for (let i = 0; i < metadataKVCount; i++) { - const keyResult = await this._fileReader.readString(readOffset); - const valueType = await this._fileReader.readUint32(readOffset); - metadata[keyResult] = await this._readGgufValue(valueType, readOffset); - } - - return { - version, - tensorCount: GgufFileReader.castNumber(tensorCount), - metadata: metadata, - metadataSize: readOffset.offset - initialOffset - }; - } - - private async _parseTensorInfo(tensorCount: number | bigint, readOffset: GgufReadOffset) { - const initialOffset = readOffset.offset; - const tensorInfo: GgufTensorInfo[] = []; - - for (let i = 0n; i < BigInt(tensorCount); i++) { - const name = await this._fileReader.readString(readOffset); - const dimensionsNumber = await this._fileReader.readUint32(readOffset); - const dimensions: (number | bigint)[] = []; - - for (let i = 0; i < dimensionsNumber; i++) { - const dimension = await this._fileReader.readUint64(readOffset); - dimensions.push(GgufFileReader.castNumber(dimension)); - } - - const ggmlType = await this._fileReader.readUint32(readOffset); - const offset = await this._fileReader.readUint64(readOffset); - - tensorInfo.push({ - name, - dimensions, - ggmlType: ggmlType as GgmlType, - offset: GgufFileReader.castNumber(offset) - }); - } - - return { - tensorInfo, - tensorInfoSize: readOffset.offset - initialOffset - }; - } - - private static _getNestedObject(key: string, currentNestedObject: any) { - const nestedKey = key.split("."); - const lastKey = nestedKey.pop()!; - - while (nestedKey.length > 0) { - const currentKey = nestedKey.shift()!; - if (!Object.hasOwn(currentNestedObject, currentKey)) - currentNestedObject[currentKey] = {}; - else { - const value = currentNestedObject[currentKey]; - if (value instanceof Array || value == null || typeof value !== "object") - throw new Error( - `Cannot create nested object for key "${key}". The key "${currentKey}" is already occupied by a non-object value.` - ); - } - - currentNestedObject = currentNestedObject[currentKey]; - } - - return { - lastObject: currentNestedObject, - lastKey - }; - } -} diff --git a/src/gguf/ggufParser/fileReaders/GgufFileReader.ts b/src/gguf/ggufParser/fileReaders/GgufFileReader.ts index 7f45cf22..f4dddd91 100644 --- a/src/gguf/ggufParser/fileReaders/GgufFileReader.ts +++ b/src/gguf/ggufParser/fileReaders/GgufFileReader.ts @@ -74,16 +74,6 @@ export abstract class GgufFileReader { return response.readUInt8() === 1; } - public async readString(offset: number | GgufReadOffset) { - const readOffset = GgufReadOffset.resolveReadOffset(offset); - const length = Number(await this.readUint64(readOffset)); - - const readLength = valueTypeToBytesToRead.uint8 * length; - const stringBytes = await this.readByteRange(readOffset, readLength); - - return String.fromCharCode(...stringBytes); - } - protected _addToBuffer(buffer: Buffer){ this._buffer = Buffer.concat([this._buffer, buffer]); } @@ -97,8 +87,10 @@ export abstract class GgufFileReader { return response; } - public static castNumber(value: bigint) { - if (value > Number.MAX_SAFE_INTEGER) return value; + public static castNumberIfSafe(value: bigint) { + if (value > Number.MAX_SAFE_INTEGER) + return value; + return Number(value); } } diff --git a/src/gguf/ggufParser/parseGguf.ts b/src/gguf/ggufParser/parseGguf.ts new file mode 100644 index 00000000..ecf836d9 --- /dev/null +++ b/src/gguf/ggufParser/parseGguf.ts @@ -0,0 +1,79 @@ +import {InvalidGgufMagicError} from "../errors/InvalidGgufMagicError.js"; +import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; +import {UnsupportedError} from "../../utils/UnsupportedError.js"; +import {GgufReadOffset} from "./utils/GgufReadOffset.js"; +import {GgufFileReader, valueTypeToBytesToRead} from "./fileReaders/GgufFileReader.js"; +import {GgufFileInfo, GgufVersionParserOptions, GgufVersionParserResult} from "./types/GgufFileInfoTypes.js"; +import {GgufV2Parser} from "./parsers/GgufV2Parser.js"; + +const ggufMagic = "GGUF"; + +export async function parseGguf({ + fileReader, + readTensorInfo = true, + ignoreKeys = [], + logWarnings = true +}: { + fileReader: GgufFileReader, + readTensorInfo?: boolean, + ignoreKeys?: string[], + logWarnings?: boolean +}): Promise { + const readOffset = new GgufReadOffset(0); + const magicAndVersion = await parseMagicAndVersion(fileReader, readOffset); + const gguifInfo = await parseGgufUsingASpecificVersionParser({ + fileReader, + readTensorInfo, + ignoreKeys, + + version: magicAndVersion.version, + readOffset, + logWarnings + }); + + return { + version: magicAndVersion.version, + tensorCount: gguifInfo.tensorCount, + metadata: gguifInfo.metadata, + tensorInfo: gguifInfo.tensorInfo, + metadataSize: gguifInfo.metadataSize, + tensorInfoSize: gguifInfo.tensorInfoSize + }; +} + +async function parseMagicAndVersion(fileReader: GgufFileReader, readOffset: GgufReadOffset) { + const fileMagicBytes = await fileReader.readByteRange(readOffset, valueTypeToBytesToRead.uint8 * ggufMagic.length); + const fileMagicText = String.fromCharCode(...fileMagicBytes); + + if (fileMagicText !== ggufMagic) + throw new InvalidGgufMagicError(ggufMagic, fileMagicText); + + const version = await fileReader.readUint32(readOffset); + + return { + magic: ggufMagic, + version + }; +} + +async function parseGgufUsingASpecificVersionParser( + specificVersionParserOptions: GgufVersionParserOptions +): Promise { + switch (specificVersionParserOptions.version) { + case 1: + throw new UnsupportedError("GGUF version 1 is not supported by llama.cpp anymore"); + + case 2: + case 3: + return await (new GgufV2Parser(specificVersionParserOptions)).parse(); + + default: + if (specificVersionParserOptions.logWarnings) + console.warn( + getConsoleLogPrefix() + + `Unsupported GGUF version: ${specificVersionParserOptions.version}. Parsing it as version 3` + ); + + return await (new GgufV2Parser(specificVersionParserOptions)).parse(); + } +} diff --git a/src/gguf/ggufParser/parsers/GgufV2Parser.ts b/src/gguf/ggufParser/parsers/GgufV2Parser.ts new file mode 100644 index 00000000..6eedac92 --- /dev/null +++ b/src/gguf/ggufParser/parsers/GgufV2Parser.ts @@ -0,0 +1,142 @@ +import {GgufFileReader, valueTypeToBytesToRead} from "../fileReaders/GgufFileReader.js"; +import {GgufReadOffset} from "../utils/GgufReadOffset.js"; +import {UnsupportedGgufValueTypeError} from "../../errors/UnsupportedGgufValueTypeError.js"; +import { + GgufValueType, GgufVersionParserOptions, GgufVersionParserResult, MetadataKeyValueRecord, MetadataValue +} from "../types/GgufFileInfoTypes.js"; +import {GgufMetadata} from "../types/GgufMetadataTypes.js"; +import {GgmlType, GgufTensorInfo} from "../types/GgufTensorInfoTypes.js"; +import {convertMetadataKeyValueRecordToNestedObject} from "../utils/convertMetadataKeyValueRecordToNestedObject.js"; + +export class GgufV2Parser { + private readonly _fileReader: GgufFileReader; + private readonly _readTensorInfo: boolean; + private readonly _ignoreKeys: string[]; + private readonly _readOffset: GgufReadOffset; + private readonly _logWarnings: boolean; + + public constructor({fileReader, readTensorInfo = true, ignoreKeys = [], readOffset, logWarnings}: GgufVersionParserOptions) { + this._fileReader = fileReader; + this._readTensorInfo = readTensorInfo; + this._ignoreKeys = ignoreKeys; + this._readOffset = readOffset; + this._logWarnings = logWarnings; + } + + public async parse(): Promise { + const readOffset = this._readOffset; + const initialOffset = readOffset.offset; + + const headerReadResult = await this._readRawHeader(readOffset); + const tensorReadResult = this._readTensorInfo + ? await this._parseTensorInfo(headerReadResult.tensorCount, readOffset) + : null; + const metadata = convertMetadataKeyValueRecordToNestedObject(headerReadResult.metadata, { + logOverrideWarnings: this._logWarnings, + ignoreKeys: this._ignoreKeys + }); + + return { + tensorCount: headerReadResult.tensorCount, + metadata: metadata as any as GgufMetadata, + tensorInfo: tensorReadResult?.tensorInfo, + metadataSize: headerReadResult.headerSize + initialOffset, + tensorInfoSize: tensorReadResult?.tensorInfoSize + }; + } + + protected async _readGgufValue(type: GgufValueType, offset: number | GgufReadOffset): Promise { + const readOffset = GgufReadOffset.resolveReadOffset(offset); + + switch (type) { + case GgufValueType.Uint8: return await this._fileReader.readUint8(readOffset); + case GgufValueType.Int8: return await this._fileReader.readInt8(readOffset); + case GgufValueType.Uint16: return await this._fileReader.readUint16(readOffset); + case GgufValueType.Int16: return await this._fileReader.readInt16(readOffset); + case GgufValueType.Uint32: return await this._fileReader.readUint32(readOffset); + case GgufValueType.Int32: return await this._fileReader.readInt32(readOffset); + case GgufValueType.Float32: return await this._fileReader.readFloat32(readOffset); + case GgufValueType.Bool: return await this._fileReader.readBool(readOffset); + case GgufValueType.String: return await this._readStringValue(readOffset); + case GgufValueType.Uint64: return await this._fileReader.readUint64(readOffset); + case GgufValueType.Int64: return await this._fileReader.readInt64(readOffset); + case GgufValueType.Float64: return await this._fileReader.readFloat64(readOffset); + } + + if (type === GgufValueType.Array) { + const arrayType = await this._fileReader.readUint32(readOffset); + const arrayLength = await this._fileReader.readUint64(readOffset); + + const arrayValues: MetadataValue[] = []; + for (let i = 0; i < arrayLength; i++) { + const value = await this._readGgufValue(arrayType, readOffset); + arrayValues.push(value); + } + return arrayValues; + } + + throw new UnsupportedGgufValueTypeError(type); + } + + protected async _readStringValue(offset: number | GgufReadOffset) { + const readOffset = GgufReadOffset.resolveReadOffset(offset); + const length = Number(await this._fileReader.readUint64(readOffset)); + + const readLength = valueTypeToBytesToRead.uint8 * length; + const stringBytes = await this._fileReader.readByteRange(readOffset, readLength); + + return String.fromCharCode(...stringBytes); + } + + protected async _readRawHeader(readOffset: GgufReadOffset) { + const initialOffset = readOffset.offset; + + const tensorCount = await this._fileReader.readUint64(readOffset); + const metadataKVCount = Number(await this._fileReader.readUint64(readOffset)); + + const metadata: MetadataKeyValueRecord = {}; + + for (let i = 0; i < metadataKVCount; i++) { + const keyResult = await this._readStringValue(readOffset); + const valueType = await this._fileReader.readUint32(readOffset); + metadata[keyResult] = await this._readGgufValue(valueType, readOffset); + } + + return { + tensorCount: GgufFileReader.castNumberIfSafe(tensorCount), + metadata: metadata, + headerSize: readOffset.offset - initialOffset + }; + } + + private async _parseTensorInfo(tensorCount: number | bigint, readOffset: GgufReadOffset) { + const initialOffset = readOffset.offset; + const tensorInfo: GgufTensorInfo[] = []; + + for (let i = 0n; i < BigInt(tensorCount); i++) { + const name = await this._readStringValue(readOffset); + const dimensionsNumber = await this._fileReader.readUint32(readOffset); + const dimensions: (number | bigint)[] = []; + + for (let i = 0; i < dimensionsNumber; i++) { + const dimension = await this._fileReader.readUint64(readOffset); + dimensions.push(GgufFileReader.castNumberIfSafe(dimension)); + } + + const ggmlType = await this._fileReader.readUint32(readOffset); + const offset = await this._fileReader.readUint64(readOffset); + + tensorInfo.push({ + name, + dimensions, + ggmlType: ggmlType as GgmlType, + offset: GgufFileReader.castNumberIfSafe(offset) + }); + } + + return { + tensorInfo, + tensorInfoSize: readOffset.offset - initialOffset + }; + } +} diff --git a/src/gguf/ggufParser/types/GgufFileInfoTypes.ts b/src/gguf/ggufParser/types/GgufFileInfoTypes.ts index 129722bd..f1a8e789 100644 --- a/src/gguf/ggufParser/types/GgufFileInfoTypes.ts +++ b/src/gguf/ggufParser/types/GgufFileInfoTypes.ts @@ -1,8 +1,16 @@ +import {GgufReadOffset} from "../utils/GgufReadOffset.js"; +import {GgufFileReader} from "../fileReaders/GgufFileReader.js"; import {GgufMetadata} from "./GgufMetadataTypes.js"; import {GgufTensorInfo} from "./GgufTensorInfoTypes.js"; +export type MetadataValue = string | number | bigint | boolean | MetadataValue[]; +export type MetadataKeyValueRecord = Record; +export type MetadataNestedObject = { + [key: string]: MetadataValue | MetadataNestedObject +}; + export type GgufFileInfo = { - version: 3 | number, + version: 2 | 3 | number, tensorCount: number | bigint, metadata: GgufMetadata, metadataSize: number, @@ -13,3 +21,39 @@ export type GgufFileInfo = { /** can be null if `readTensorInfo` is set to `false` */ tensorInfoSize?: number }; + + +// source: `enum gguf_type` in `ggml.h` in the `llama.cpp` source code +export const enum GgufValueType { + Uint8 = 0, + Int8 = 1, + Uint16 = 2, + Int16 = 3, + Uint32 = 4, + Int32 = 5, + Float32 = 6, + Bool = 7, + String = 8, + Array = 9, + Uint64 = 10, + Int64 = 11, + Float64 = 12 +} + +export type GgufVersionParserOptions = { + fileReader: GgufFileReader, + readTensorInfo?: boolean, + ignoreKeys?: string[], + + version: number, + readOffset: GgufReadOffset, + logWarnings: boolean +}; + +export type GgufVersionParserResult = { + tensorCount: number | bigint, + metadata: GgufMetadata, + tensorInfo?: GgufTensorInfo[], + metadataSize: number, + tensorInfoSize?: number +}; diff --git a/src/gguf/ggufParser/utils/convertMetadataKeyValueRecordToNestedObject.ts b/src/gguf/ggufParser/utils/convertMetadataKeyValueRecordToNestedObject.ts new file mode 100644 index 00000000..51b731fe --- /dev/null +++ b/src/gguf/ggufParser/utils/convertMetadataKeyValueRecordToNestedObject.ts @@ -0,0 +1,59 @@ +import {getConsoleLogPrefix} from "../../../utils/getConsoleLogPrefix.js"; +import {MetadataKeyValueRecord, MetadataNestedObject, MetadataValue} from "../types/GgufFileInfoTypes.js"; + +export function convertMetadataKeyValueRecordToNestedObject( + keyValueRecord: MetadataKeyValueRecord, + { + logOverrideWarnings = true, + ignoreKeys = [] + }: { + logOverrideWarnings?: boolean, + ignoreKeys?: string[] + } = {} +) { + const nestedObject: Record = {}; + const ignoreKeySet = new Set(ignoreKeys); + + for (const [key, value] of Object.entries(keyValueRecord)) { + if (ignoreKeySet.has(key)) + continue; + + const {lastObject, lastKey} = getNestedObject(key, nestedObject); + if (Object.hasOwn(lastObject, lastKey) && logOverrideWarnings) + console.warn(getConsoleLogPrefix() + `Metadata key "${key}" is already occupied by a value. Overwriting it.`); + + lastObject[lastKey] = value; + } + + return nestedObject; +} + +function getNestedObject(key: string, nestedObject: MetadataNestedObject) { + const nestedKey = key.split("."); + const lastKey = nestedKey.pop()!; + + let currentObject = nestedObject; + + while (nestedKey.length > 0) { + const currentKey = nestedKey.shift()!; + if (!Object.hasOwn(currentObject, currentKey)) { + const nextCurrentObject = {}; + currentObject[currentKey] = nextCurrentObject; + + currentObject = nextCurrentObject; + } else { + const value = currentObject[currentKey]; + if (value instanceof Array || value == null || typeof value !== "object") + throw new Error( + `Cannot create nested object for key "${key}". The key "${currentKey}" is already occupied by a non-object value.` + ); + + currentObject = value; + } + } + + return { + lastObject: currentObject, + lastKey + }; +} diff --git a/test/modelDependent/functionary/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts index 68628895..381bbe12 100644 --- a/test/modelDependent/functionary/gguf.test.ts +++ b/test/modelDependent/functionary/gguf.test.ts @@ -1,6 +1,6 @@ import {describe, expect, it, test} from "vitest"; import {GgufFsFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufFsFileReader.js"; -import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; +import {parseGguf} from "../../../src/gguf/ggufParser/parseGguf.js"; import {getModelFile} from "../../utils/modelFiles.js"; import {GgufInsights} from "../../../src/gguf/GgufInsights.js"; import {getTestLlama} from "../../utils/getTestLlama.js"; @@ -20,11 +20,10 @@ describe("GGUF Parser", async () => { it("should parse local gguf model", async () => { const fileReader = new GgufFsFileReader({filePath: modelPath}); - const ggufParser = new GgufParser({ + + const metadata = await parseGguf({ fileReader: fileReader }); - - const metadata = await ggufParser.parseFileInfo(); expect(metadata.tensorInfo!.length).to.be.eql(Number(metadata.tensorCount)); expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); @@ -32,9 +31,10 @@ describe("GGUF Parser", async () => { it("should calculate GGUF VRAM Usage", async () => { const fileReader = new GgufFsFileReader({filePath: modelPath}); - const ggufParser = new GgufParser({fileReader: fileReader}); - const metadata = await ggufParser.parseFileInfo(); + const metadata = await parseGguf({ + fileReader: fileReader + }); const ggufInsights = new GgufInsights(metadata); diff --git a/test/standalone/gguf/gguf.test.ts b/test/standalone/gguf/gguf.test.ts index 724cb0cd..708b7d28 100644 --- a/test/standalone/gguf/gguf.test.ts +++ b/test/standalone/gguf/gguf.test.ts @@ -1,5 +1,5 @@ import {describe, expect, it, test} from "vitest"; -import {GgufParser} from "../../../src/gguf/ggufParser/GgufParser.js"; +import {parseGguf} from "../../../src/gguf/ggufParser/parseGguf.js"; import {GgufNetworkFetchFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufNetworkFetchFileReader.js"; import {simplifyGgufInfoForTestSnapshot} from "../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; @@ -17,24 +17,22 @@ describe("GGUF Parser", async () => { it("should parse remote gguf model", async () => { const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); - const ggufParser = new GgufParser({ + + const metadata = await parseGguf({ fileReader: fileReader }); - const metadata = await ggufParser.parseFileInfo(); - expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); }); it("should parse remote gguf model without tensor info", async () => { const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); - const ggufParser = new GgufParser({ + + const metadata = await parseGguf({ fileReader: fileReader, readTensorInfo: false }); - const metadata = await ggufParser.parseFileInfo(); - expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); }); }); From 98e911dfd9881261fa3c9350e45c970773e37115 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 20 Mar 2024 18:09:01 +0200 Subject: [PATCH 16/52] refactor: rename `ggufParser` directory to `parser` --- src/cli/commands/inspect/commands/InspectGgufCommand.ts | 2 +- src/gguf/GgufInsights.ts | 4 ++-- src/gguf/getGgufFileInfo.ts | 6 +++--- .../{ggufParser => parser}/fileReaders/GgufFileReader.ts | 0 .../{ggufParser => parser}/fileReaders/GgufFsFileReader.ts | 0 .../fileReaders/GgufNetworkFetchFileReader.ts | 0 src/gguf/{ggufParser => parser}/parseGguf.ts | 0 src/gguf/{ggufParser => parser}/parsers/GgufV2Parser.ts | 0 src/gguf/{ggufParser => parser}/types/GgufFileInfoTypes.ts | 0 src/gguf/{ggufParser => parser}/types/GgufMetadataTypes.ts | 0 .../{ggufParser => parser}/types/GgufTensorInfoTypes.ts | 0 src/gguf/{ggufParser => parser}/utils/GgufReadOffset.ts | 0 .../utils/convertMetadataKeyValueRecordToNestedObject.ts | 0 .../{ggufParser => parser}/utils/getGgufFileTypeName.ts | 0 .../{ggufParser => parser}/utils/getGgufMetadataLlmData.ts | 0 test/modelDependent/functionary/gguf.test.ts | 4 ++-- test/standalone/gguf/gguf.test.ts | 4 ++-- test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts | 2 +- 18 files changed, 11 insertions(+), 11 deletions(-) rename src/gguf/{ggufParser => parser}/fileReaders/GgufFileReader.ts (100%) rename src/gguf/{ggufParser => parser}/fileReaders/GgufFsFileReader.ts (100%) rename src/gguf/{ggufParser => parser}/fileReaders/GgufNetworkFetchFileReader.ts (100%) rename src/gguf/{ggufParser => parser}/parseGguf.ts (100%) rename src/gguf/{ggufParser => parser}/parsers/GgufV2Parser.ts (100%) rename src/gguf/{ggufParser => parser}/types/GgufFileInfoTypes.ts (100%) rename src/gguf/{ggufParser => parser}/types/GgufMetadataTypes.ts (100%) rename src/gguf/{ggufParser => parser}/types/GgufTensorInfoTypes.ts (100%) rename src/gguf/{ggufParser => parser}/utils/GgufReadOffset.ts (100%) rename src/gguf/{ggufParser => parser}/utils/convertMetadataKeyValueRecordToNestedObject.ts (100%) rename src/gguf/{ggufParser => parser}/utils/getGgufFileTypeName.ts (100%) rename src/gguf/{ggufParser => parser}/utils/getGgufMetadataLlmData.ts (100%) diff --git a/src/cli/commands/inspect/commands/InspectGgufCommand.ts b/src/cli/commands/inspect/commands/InspectGgufCommand.ts index 92e362b8..c64c7e9d 100644 --- a/src/cli/commands/inspect/commands/InspectGgufCommand.ts +++ b/src/cli/commands/inspect/commands/InspectGgufCommand.ts @@ -4,7 +4,7 @@ import chalk from "chalk"; import bytes from "bytes"; import {getGgufFileInfo} from "../../../../gguf/getGgufFileInfo.js"; import {prettyPrintObject, PrettyPrintObjectOptions} from "../../../../utils/prettyPrintObject.js"; -import {getGgufFileTypeName} from "../../../../gguf/ggufParser/utils/getGgufFileTypeName.js"; +import {getGgufFileTypeName} from "../../../../gguf/parser/utils/getGgufFileTypeName.js"; type InspectGgufCommand = { path: string, diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts index 68e399d3..45664aa1 100644 --- a/src/gguf/GgufInsights.ts +++ b/src/gguf/GgufInsights.ts @@ -1,5 +1,5 @@ -import {getGgufMetadataLlmData} from "./ggufParser/utils/getGgufMetadataLlmData.js"; -import {GgufMetadata} from "./ggufParser/types/GgufMetadataTypes.js"; +import {getGgufMetadataLlmData} from "./parser/utils/getGgufMetadataLlmData.js"; +import {GgufMetadata} from "./parser/types/GgufMetadataTypes.js"; export class GgufInsights { public readonly metadata: GgufMetadata; diff --git a/src/gguf/getGgufFileInfo.ts b/src/gguf/getGgufFileInfo.ts index be964666..a629d907 100644 --- a/src/gguf/getGgufFileInfo.ts +++ b/src/gguf/getGgufFileInfo.ts @@ -1,7 +1,7 @@ import retry from "async-retry"; -import {parseGguf} from "./ggufParser/parseGguf.js"; -import {GgufNetworkFetchFileReader} from "./ggufParser/fileReaders/GgufNetworkFetchFileReader.js"; -import {GgufFsFileReader} from "./ggufParser/fileReaders/GgufFsFileReader.js"; +import {parseGguf} from "./parser/parseGguf.js"; +import {GgufNetworkFetchFileReader} from "./parser/fileReaders/GgufNetworkFetchFileReader.js"; +import {GgufFsFileReader} from "./parser/fileReaders/GgufFsFileReader.js"; import {ggufDefaultRetryOptions} from "./consts.js"; diff --git a/src/gguf/ggufParser/fileReaders/GgufFileReader.ts b/src/gguf/parser/fileReaders/GgufFileReader.ts similarity index 100% rename from src/gguf/ggufParser/fileReaders/GgufFileReader.ts rename to src/gguf/parser/fileReaders/GgufFileReader.ts diff --git a/src/gguf/ggufParser/fileReaders/GgufFsFileReader.ts b/src/gguf/parser/fileReaders/GgufFsFileReader.ts similarity index 100% rename from src/gguf/ggufParser/fileReaders/GgufFsFileReader.ts rename to src/gguf/parser/fileReaders/GgufFsFileReader.ts diff --git a/src/gguf/ggufParser/fileReaders/GgufNetworkFetchFileReader.ts b/src/gguf/parser/fileReaders/GgufNetworkFetchFileReader.ts similarity index 100% rename from src/gguf/ggufParser/fileReaders/GgufNetworkFetchFileReader.ts rename to src/gguf/parser/fileReaders/GgufNetworkFetchFileReader.ts diff --git a/src/gguf/ggufParser/parseGguf.ts b/src/gguf/parser/parseGguf.ts similarity index 100% rename from src/gguf/ggufParser/parseGguf.ts rename to src/gguf/parser/parseGguf.ts diff --git a/src/gguf/ggufParser/parsers/GgufV2Parser.ts b/src/gguf/parser/parsers/GgufV2Parser.ts similarity index 100% rename from src/gguf/ggufParser/parsers/GgufV2Parser.ts rename to src/gguf/parser/parsers/GgufV2Parser.ts diff --git a/src/gguf/ggufParser/types/GgufFileInfoTypes.ts b/src/gguf/parser/types/GgufFileInfoTypes.ts similarity index 100% rename from src/gguf/ggufParser/types/GgufFileInfoTypes.ts rename to src/gguf/parser/types/GgufFileInfoTypes.ts diff --git a/src/gguf/ggufParser/types/GgufMetadataTypes.ts b/src/gguf/parser/types/GgufMetadataTypes.ts similarity index 100% rename from src/gguf/ggufParser/types/GgufMetadataTypes.ts rename to src/gguf/parser/types/GgufMetadataTypes.ts diff --git a/src/gguf/ggufParser/types/GgufTensorInfoTypes.ts b/src/gguf/parser/types/GgufTensorInfoTypes.ts similarity index 100% rename from src/gguf/ggufParser/types/GgufTensorInfoTypes.ts rename to src/gguf/parser/types/GgufTensorInfoTypes.ts diff --git a/src/gguf/ggufParser/utils/GgufReadOffset.ts b/src/gguf/parser/utils/GgufReadOffset.ts similarity index 100% rename from src/gguf/ggufParser/utils/GgufReadOffset.ts rename to src/gguf/parser/utils/GgufReadOffset.ts diff --git a/src/gguf/ggufParser/utils/convertMetadataKeyValueRecordToNestedObject.ts b/src/gguf/parser/utils/convertMetadataKeyValueRecordToNestedObject.ts similarity index 100% rename from src/gguf/ggufParser/utils/convertMetadataKeyValueRecordToNestedObject.ts rename to src/gguf/parser/utils/convertMetadataKeyValueRecordToNestedObject.ts diff --git a/src/gguf/ggufParser/utils/getGgufFileTypeName.ts b/src/gguf/parser/utils/getGgufFileTypeName.ts similarity index 100% rename from src/gguf/ggufParser/utils/getGgufFileTypeName.ts rename to src/gguf/parser/utils/getGgufFileTypeName.ts diff --git a/src/gguf/ggufParser/utils/getGgufMetadataLlmData.ts b/src/gguf/parser/utils/getGgufMetadataLlmData.ts similarity index 100% rename from src/gguf/ggufParser/utils/getGgufMetadataLlmData.ts rename to src/gguf/parser/utils/getGgufMetadataLlmData.ts diff --git a/test/modelDependent/functionary/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts index 381bbe12..9db89c72 100644 --- a/test/modelDependent/functionary/gguf.test.ts +++ b/test/modelDependent/functionary/gguf.test.ts @@ -1,6 +1,6 @@ import {describe, expect, it, test} from "vitest"; -import {GgufFsFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufFsFileReader.js"; -import {parseGguf} from "../../../src/gguf/ggufParser/parseGguf.js"; +import {GgufFsFileReader} from "../../../src/gguf/parser/fileReaders/GgufFsFileReader.js"; +import {parseGguf} from "../../../src/gguf/parser/parseGguf.js"; import {getModelFile} from "../../utils/modelFiles.js"; import {GgufInsights} from "../../../src/gguf/GgufInsights.js"; import {getTestLlama} from "../../utils/getTestLlama.js"; diff --git a/test/standalone/gguf/gguf.test.ts b/test/standalone/gguf/gguf.test.ts index 708b7d28..29bdc7a5 100644 --- a/test/standalone/gguf/gguf.test.ts +++ b/test/standalone/gguf/gguf.test.ts @@ -1,6 +1,6 @@ import {describe, expect, it, test} from "vitest"; -import {parseGguf} from "../../../src/gguf/ggufParser/parseGguf.js"; -import {GgufNetworkFetchFileReader} from "../../../src/gguf/ggufParser/fileReaders/GgufNetworkFetchFileReader.js"; +import {parseGguf} from "../../../src/gguf/parser/parseGguf.js"; +import {GgufNetworkFetchFileReader} from "../../../src/gguf/parser/fileReaders/GgufNetworkFetchFileReader.js"; import {simplifyGgufInfoForTestSnapshot} from "../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; diff --git a/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts b/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts index f09299f7..43ab42f6 100644 --- a/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts +++ b/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts @@ -1,4 +1,4 @@ -import {GgufFileInfo} from "../../../src/gguf/ggufParser/types/GgufFileInfoTypes.js"; +import {GgufFileInfo} from "../../../src/gguf/parser/types/GgufFileInfoTypes.js"; export function simplifyGgufInfoForTestSnapshot(ggufFileInfo: GgufFileInfo) { const ggufFileInfoCopy = structuredClone(ggufFileInfo); From 550188b10f407ee1adaef587972e33714070169d Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 20 Mar 2024 18:11:02 +0200 Subject: [PATCH 17/52] refactor: move files --- src/cli/commands/inspect/commands/InspectGgufCommand.ts | 2 +- src/gguf/GgufInsights.ts | 4 ++-- src/gguf/{parser => }/fileReaders/GgufFileReader.ts | 0 src/gguf/{parser => }/fileReaders/GgufFsFileReader.ts | 2 +- .../fileReaders/GgufNetworkFetchFileReader.ts | 2 +- src/gguf/getGgufFileInfo.ts | 4 ++-- src/gguf/parser/{parsers => }/GgufV2Parser.ts | 2 +- src/gguf/parser/parseGguf.ts | 8 ++++---- src/gguf/{parser => }/types/GgufFileInfoTypes.ts | 0 src/gguf/{parser => }/types/GgufMetadataTypes.ts | 0 src/gguf/{parser => }/types/GgufTensorInfoTypes.ts | 0 src/gguf/{parser => }/utils/GgufReadOffset.ts | 0 .../utils/convertMetadataKeyValueRecordToNestedObject.ts | 2 +- src/gguf/{parser => }/utils/getGgufFileTypeName.ts | 0 src/gguf/{parser => }/utils/getGgufMetadataLlmData.ts | 2 +- test/modelDependent/functionary/gguf.test.ts | 2 +- test/standalone/gguf/gguf.test.ts | 2 +- test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts | 2 +- 18 files changed, 17 insertions(+), 17 deletions(-) rename src/gguf/{parser => }/fileReaders/GgufFileReader.ts (100%) rename src/gguf/{parser => }/fileReaders/GgufFsFileReader.ts (98%) rename src/gguf/{parser => }/fileReaders/GgufNetworkFetchFileReader.ts (99%) rename src/gguf/parser/{parsers => }/GgufV2Parser.ts (98%) rename src/gguf/{parser => }/types/GgufFileInfoTypes.ts (100%) rename src/gguf/{parser => }/types/GgufMetadataTypes.ts (100%) rename src/gguf/{parser => }/types/GgufTensorInfoTypes.ts (100%) rename src/gguf/{parser => }/utils/GgufReadOffset.ts (100%) rename src/gguf/{parser => }/utils/convertMetadataKeyValueRecordToNestedObject.ts (96%) rename src/gguf/{parser => }/utils/getGgufFileTypeName.ts (100%) rename src/gguf/{parser => }/utils/getGgufMetadataLlmData.ts (84%) diff --git a/src/cli/commands/inspect/commands/InspectGgufCommand.ts b/src/cli/commands/inspect/commands/InspectGgufCommand.ts index c64c7e9d..eca93533 100644 --- a/src/cli/commands/inspect/commands/InspectGgufCommand.ts +++ b/src/cli/commands/inspect/commands/InspectGgufCommand.ts @@ -4,7 +4,7 @@ import chalk from "chalk"; import bytes from "bytes"; import {getGgufFileInfo} from "../../../../gguf/getGgufFileInfo.js"; import {prettyPrintObject, PrettyPrintObjectOptions} from "../../../../utils/prettyPrintObject.js"; -import {getGgufFileTypeName} from "../../../../gguf/parser/utils/getGgufFileTypeName.js"; +import {getGgufFileTypeName} from "../../../../gguf/utils/getGgufFileTypeName.js"; type InspectGgufCommand = { path: string, diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts index 45664aa1..29e09b2e 100644 --- a/src/gguf/GgufInsights.ts +++ b/src/gguf/GgufInsights.ts @@ -1,5 +1,5 @@ -import {getGgufMetadataLlmData} from "./parser/utils/getGgufMetadataLlmData.js"; -import {GgufMetadata} from "./parser/types/GgufMetadataTypes.js"; +import {getGgufMetadataLlmData} from "./utils/getGgufMetadataLlmData.js"; +import {GgufMetadata} from "./types/GgufMetadataTypes.js"; export class GgufInsights { public readonly metadata: GgufMetadata; diff --git a/src/gguf/parser/fileReaders/GgufFileReader.ts b/src/gguf/fileReaders/GgufFileReader.ts similarity index 100% rename from src/gguf/parser/fileReaders/GgufFileReader.ts rename to src/gguf/fileReaders/GgufFileReader.ts diff --git a/src/gguf/parser/fileReaders/GgufFsFileReader.ts b/src/gguf/fileReaders/GgufFsFileReader.ts similarity index 98% rename from src/gguf/parser/fileReaders/GgufFsFileReader.ts rename to src/gguf/fileReaders/GgufFsFileReader.ts index cf253271..5f5113a1 100644 --- a/src/gguf/parser/fileReaders/GgufFsFileReader.ts +++ b/src/gguf/fileReaders/GgufFsFileReader.ts @@ -2,7 +2,7 @@ import fs from "node:fs/promises"; import retry from "async-retry"; import {withLock} from "lifecycle-utils"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {defaultExtraAllocationSize, ggufDefaultRetryOptions} from "../../consts.js"; +import {defaultExtraAllocationSize, ggufDefaultRetryOptions} from "../consts.js"; import {GgufFileReader} from "./GgufFileReader.js"; type GgufFsFileReaderOptions = { diff --git a/src/gguf/parser/fileReaders/GgufNetworkFetchFileReader.ts b/src/gguf/fileReaders/GgufNetworkFetchFileReader.ts similarity index 99% rename from src/gguf/parser/fileReaders/GgufNetworkFetchFileReader.ts rename to src/gguf/fileReaders/GgufNetworkFetchFileReader.ts index adcc370f..0a2abee6 100644 --- a/src/gguf/parser/fileReaders/GgufNetworkFetchFileReader.ts +++ b/src/gguf/fileReaders/GgufNetworkFetchFileReader.ts @@ -1,7 +1,7 @@ import retry from "async-retry"; import {withLock} from "lifecycle-utils"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {defaultExtraAllocationSize, ggufDefaultRetryOptions} from "../../consts.js"; +import {defaultExtraAllocationSize, ggufDefaultRetryOptions} from "../consts.js"; import {GgufFileReader} from "./GgufFileReader.js"; type GgufFetchFileReaderOptions = { diff --git a/src/gguf/getGgufFileInfo.ts b/src/gguf/getGgufFileInfo.ts index a629d907..29a81a0d 100644 --- a/src/gguf/getGgufFileInfo.ts +++ b/src/gguf/getGgufFileInfo.ts @@ -1,7 +1,7 @@ import retry from "async-retry"; import {parseGguf} from "./parser/parseGguf.js"; -import {GgufNetworkFetchFileReader} from "./parser/fileReaders/GgufNetworkFetchFileReader.js"; -import {GgufFsFileReader} from "./parser/fileReaders/GgufFsFileReader.js"; +import {GgufNetworkFetchFileReader} from "./fileReaders/GgufNetworkFetchFileReader.js"; +import {GgufFsFileReader} from "./fileReaders/GgufFsFileReader.js"; import {ggufDefaultRetryOptions} from "./consts.js"; diff --git a/src/gguf/parser/parsers/GgufV2Parser.ts b/src/gguf/parser/GgufV2Parser.ts similarity index 98% rename from src/gguf/parser/parsers/GgufV2Parser.ts rename to src/gguf/parser/GgufV2Parser.ts index 6eedac92..d2b80e71 100644 --- a/src/gguf/parser/parsers/GgufV2Parser.ts +++ b/src/gguf/parser/GgufV2Parser.ts @@ -1,6 +1,6 @@ import {GgufFileReader, valueTypeToBytesToRead} from "../fileReaders/GgufFileReader.js"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {UnsupportedGgufValueTypeError} from "../../errors/UnsupportedGgufValueTypeError.js"; +import {UnsupportedGgufValueTypeError} from "../errors/UnsupportedGgufValueTypeError.js"; import { GgufValueType, GgufVersionParserOptions, GgufVersionParserResult, MetadataKeyValueRecord, MetadataValue } from "../types/GgufFileInfoTypes.js"; diff --git a/src/gguf/parser/parseGguf.ts b/src/gguf/parser/parseGguf.ts index ecf836d9..fc225a09 100644 --- a/src/gguf/parser/parseGguf.ts +++ b/src/gguf/parser/parseGguf.ts @@ -1,10 +1,10 @@ import {InvalidGgufMagicError} from "../errors/InvalidGgufMagicError.js"; import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; import {UnsupportedError} from "../../utils/UnsupportedError.js"; -import {GgufReadOffset} from "./utils/GgufReadOffset.js"; -import {GgufFileReader, valueTypeToBytesToRead} from "./fileReaders/GgufFileReader.js"; -import {GgufFileInfo, GgufVersionParserOptions, GgufVersionParserResult} from "./types/GgufFileInfoTypes.js"; -import {GgufV2Parser} from "./parsers/GgufV2Parser.js"; +import {GgufReadOffset} from "../utils/GgufReadOffset.js"; +import {GgufFileReader, valueTypeToBytesToRead} from "../fileReaders/GgufFileReader.js"; +import {GgufFileInfo, GgufVersionParserOptions, GgufVersionParserResult} from "../types/GgufFileInfoTypes.js"; +import {GgufV2Parser} from "./GgufV2Parser.js"; const ggufMagic = "GGUF"; diff --git a/src/gguf/parser/types/GgufFileInfoTypes.ts b/src/gguf/types/GgufFileInfoTypes.ts similarity index 100% rename from src/gguf/parser/types/GgufFileInfoTypes.ts rename to src/gguf/types/GgufFileInfoTypes.ts diff --git a/src/gguf/parser/types/GgufMetadataTypes.ts b/src/gguf/types/GgufMetadataTypes.ts similarity index 100% rename from src/gguf/parser/types/GgufMetadataTypes.ts rename to src/gguf/types/GgufMetadataTypes.ts diff --git a/src/gguf/parser/types/GgufTensorInfoTypes.ts b/src/gguf/types/GgufTensorInfoTypes.ts similarity index 100% rename from src/gguf/parser/types/GgufTensorInfoTypes.ts rename to src/gguf/types/GgufTensorInfoTypes.ts diff --git a/src/gguf/parser/utils/GgufReadOffset.ts b/src/gguf/utils/GgufReadOffset.ts similarity index 100% rename from src/gguf/parser/utils/GgufReadOffset.ts rename to src/gguf/utils/GgufReadOffset.ts diff --git a/src/gguf/parser/utils/convertMetadataKeyValueRecordToNestedObject.ts b/src/gguf/utils/convertMetadataKeyValueRecordToNestedObject.ts similarity index 96% rename from src/gguf/parser/utils/convertMetadataKeyValueRecordToNestedObject.ts rename to src/gguf/utils/convertMetadataKeyValueRecordToNestedObject.ts index 51b731fe..ef0c224e 100644 --- a/src/gguf/parser/utils/convertMetadataKeyValueRecordToNestedObject.ts +++ b/src/gguf/utils/convertMetadataKeyValueRecordToNestedObject.ts @@ -1,4 +1,4 @@ -import {getConsoleLogPrefix} from "../../../utils/getConsoleLogPrefix.js"; +import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; import {MetadataKeyValueRecord, MetadataNestedObject, MetadataValue} from "../types/GgufFileInfoTypes.js"; export function convertMetadataKeyValueRecordToNestedObject( diff --git a/src/gguf/parser/utils/getGgufFileTypeName.ts b/src/gguf/utils/getGgufFileTypeName.ts similarity index 100% rename from src/gguf/parser/utils/getGgufFileTypeName.ts rename to src/gguf/utils/getGgufFileTypeName.ts diff --git a/src/gguf/parser/utils/getGgufMetadataLlmData.ts b/src/gguf/utils/getGgufMetadataLlmData.ts similarity index 84% rename from src/gguf/parser/utils/getGgufMetadataLlmData.ts rename to src/gguf/utils/getGgufMetadataLlmData.ts index 118cb104..c9ae20b2 100644 --- a/src/gguf/parser/utils/getGgufMetadataLlmData.ts +++ b/src/gguf/utils/getGgufMetadataLlmData.ts @@ -1,5 +1,5 @@ import {GgufArchitectureType, GgufMetadata} from "../types/GgufMetadataTypes.js"; -import {MergeOptionalUnionTypes} from "../../../utils/mergeUnionTypes.js"; +import {MergeOptionalUnionTypes} from "../../utils/mergeUnionTypes.js"; export function getGgufMetadataLlmData(ggufMetadata: GgufMetadata): ( GgufArchitectureType extends T diff --git a/test/modelDependent/functionary/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts index 9db89c72..3967b844 100644 --- a/test/modelDependent/functionary/gguf.test.ts +++ b/test/modelDependent/functionary/gguf.test.ts @@ -1,5 +1,5 @@ import {describe, expect, it, test} from "vitest"; -import {GgufFsFileReader} from "../../../src/gguf/parser/fileReaders/GgufFsFileReader.js"; +import {GgufFsFileReader} from "../../../src/gguf/fileReaders/GgufFsFileReader.js"; import {parseGguf} from "../../../src/gguf/parser/parseGguf.js"; import {getModelFile} from "../../utils/modelFiles.js"; import {GgufInsights} from "../../../src/gguf/GgufInsights.js"; diff --git a/test/standalone/gguf/gguf.test.ts b/test/standalone/gguf/gguf.test.ts index 29bdc7a5..b6ce2622 100644 --- a/test/standalone/gguf/gguf.test.ts +++ b/test/standalone/gguf/gguf.test.ts @@ -1,6 +1,6 @@ import {describe, expect, it, test} from "vitest"; import {parseGguf} from "../../../src/gguf/parser/parseGguf.js"; -import {GgufNetworkFetchFileReader} from "../../../src/gguf/parser/fileReaders/GgufNetworkFetchFileReader.js"; +import {GgufNetworkFetchFileReader} from "../../../src/gguf/fileReaders/GgufNetworkFetchFileReader.js"; import {simplifyGgufInfoForTestSnapshot} from "../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; diff --git a/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts b/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts index 43ab42f6..cebf6e8a 100644 --- a/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts +++ b/test/utils/helpers/simplifyGgufInfoForTestSnapshot.ts @@ -1,4 +1,4 @@ -import {GgufFileInfo} from "../../../src/gguf/parser/types/GgufFileInfoTypes.js"; +import {GgufFileInfo} from "../../../src/gguf/types/GgufFileInfoTypes.js"; export function simplifyGgufInfoForTestSnapshot(ggufFileInfo: GgufFileInfo) { const ggufFileInfoCopy = structuredClone(ggufFileInfo); From 642ccc8d2fb4eaf38ab8e0367ddbbe60f2b1fb54 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 20 Mar 2024 18:15:53 +0200 Subject: [PATCH 18/52] refactor: rename `getGgufFileInfo` --- src/cli/commands/inspect/commands/InspectGgufCommand.ts | 4 ++-- src/gguf/{getGgufFileInfo.ts => readGgufFileInfo.ts} | 5 +++-- test/modelDependent/functionary/gguf.test.ts | 4 ++-- 3 files changed, 7 insertions(+), 6 deletions(-) rename src/gguf/{getGgufFileInfo.ts => readGgufFileInfo.ts} (85%) diff --git a/src/cli/commands/inspect/commands/InspectGgufCommand.ts b/src/cli/commands/inspect/commands/InspectGgufCommand.ts index eca93533..4daf51cf 100644 --- a/src/cli/commands/inspect/commands/InspectGgufCommand.ts +++ b/src/cli/commands/inspect/commands/InspectGgufCommand.ts @@ -2,7 +2,7 @@ import path from "path"; import {CommandModule} from "yargs"; import chalk from "chalk"; import bytes from "bytes"; -import {getGgufFileInfo} from "../../../../gguf/getGgufFileInfo.js"; +import {readGgufFileInfo} from "../../../../gguf/readGgufFileInfo.js"; import {prettyPrintObject, PrettyPrintObjectOptions} from "../../../../utils/prettyPrintObject.js"; import {getGgufFileTypeName} from "../../../../gguf/utils/getGgufFileTypeName.js"; @@ -39,7 +39,7 @@ export const InspectGgufCommand: CommandModule = { else console.info(`${chalk.yellow("File:")} ${resolvedGgufPath}`); - const parsedMetadata = await getGgufFileInfo(ggufPath, {ignoreKeys: []}); + const parsedMetadata = await readGgufFileInfo(ggufPath, {ignoreKeys: []}); const fileTypeName = getGgufFileTypeName(parsedMetadata.metadata.general?.file_type); const metadataPrettyPrintOptions: PrettyPrintObjectOptions = { maxArrayValues: 10, diff --git a/src/gguf/getGgufFileInfo.ts b/src/gguf/readGgufFileInfo.ts similarity index 85% rename from src/gguf/getGgufFileInfo.ts rename to src/gguf/readGgufFileInfo.ts index 29a81a0d..e39d93eb 100644 --- a/src/gguf/getGgufFileInfo.ts +++ b/src/gguf/readGgufFileInfo.ts @@ -6,9 +6,10 @@ import {ggufDefaultRetryOptions} from "./consts.js"; /** - * Parse a GGUF file and return its metadata and tensor info (unless `readTensorInfo` is set to `false`) + * Read a GGUF file and return its metadata and tensor info (unless `readTensorInfo` is set to `false`). + * Only the parts of the file required for the metadata and tensor info are read. */ -export async function getGgufFileInfo(pathOrUrl: string, { +export async function readGgufFileInfo(pathOrUrl: string, { readTensorInfo = true, sourceType, retryOptions = ggufDefaultRetryOptions, diff --git a/test/modelDependent/functionary/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts index 3967b844..df7832e4 100644 --- a/test/modelDependent/functionary/gguf.test.ts +++ b/test/modelDependent/functionary/gguf.test.ts @@ -4,7 +4,7 @@ import {parseGguf} from "../../../src/gguf/parser/parseGguf.js"; import {getModelFile} from "../../utils/modelFiles.js"; import {GgufInsights} from "../../../src/gguf/GgufInsights.js"; import {getTestLlama} from "../../utils/getTestLlama.js"; -import {getGgufFileInfo} from "../../../src/gguf/getGgufFileInfo.js"; +import {readGgufFileInfo} from "../../../src/gguf/readGgufFileInfo.js"; import {simplifyGgufInfoForTestSnapshot} from "../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; describe("GGUF Parser", async () => { @@ -52,7 +52,7 @@ describe("GGUF Parser", async () => { }); it("should fetch GGUF metadata", async () => { - const ggufMetadataParseResult = await getGgufFileInfo(modelPath); + const ggufMetadataParseResult = await readGgufFileInfo(modelPath); expect(simplifyGgufInfoForTestSnapshot(ggufMetadataParseResult)).toMatchSnapshot(); From 7bf8a7c937208d150d0edb323dcbef5943466e85 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 20 Mar 2024 18:49:18 +0200 Subject: [PATCH 19/52] feat: add more options to `inspect gguf` command --- .../inspect/commands/InspectGgufCommand.ts | 111 +++++++++++++----- src/gguf/parser/GgufV2Parser.ts | 10 +- src/gguf/parser/GgufV3Parser.ts | 5 + src/gguf/parser/parseGguf.ts | 9 +- 4 files changed, 96 insertions(+), 39 deletions(-) create mode 100644 src/gguf/parser/GgufV3Parser.ts diff --git a/src/cli/commands/inspect/commands/InspectGgufCommand.ts b/src/cli/commands/inspect/commands/InspectGgufCommand.ts index 4daf51cf..dd2362b0 100644 --- a/src/cli/commands/inspect/commands/InspectGgufCommand.ts +++ b/src/cli/commands/inspect/commands/InspectGgufCommand.ts @@ -1,14 +1,19 @@ import path from "path"; +import process from "process"; import {CommandModule} from "yargs"; import chalk from "chalk"; import bytes from "bytes"; +import fs from "fs-extra"; import {readGgufFileInfo} from "../../../../gguf/readGgufFileInfo.js"; import {prettyPrintObject, PrettyPrintObjectOptions} from "../../../../utils/prettyPrintObject.js"; import {getGgufFileTypeName} from "../../../../gguf/utils/getGgufFileTypeName.js"; type InspectGgufCommand = { path: string, - fullTensorInfo: boolean + fullTensorInfo: boolean, + fullMetadataArrays: boolean, + plainJson: boolean, + outputToJsonFile?: string }; export const InspectGgufCommand: CommandModule = { @@ -25,46 +30,90 @@ export const InspectGgufCommand: CommandModule = { alias: "t", type: "boolean", default: false, - description: "Show the full tensor info" + description: "Show the full tensor info", + group: "Optional:" + }) + .option("fullMetadataArrays", { + alias: "m", + type: "boolean", + default: false, + description: "Print the full arrays in the metadata. Caution: those arrays can be extremely large and cover the entire terminal screen. Use with caution.", + group: "Optional:" + }) + .option("plainJson", { + type: "boolean", + default: false, + description: "Print the output as plain JSON with no formatting. Useful for piping the output to other commands. The output won't truncate any values, so it may be extremely large. Use with caution.", + group: "Optional:" + }) + .option("outputToJsonFile", { + type: "string", + description: "Path to a file to write the output to as JSON. The output won't truncate any values. The output won't be printed to the console", + group: "Optional:" }); }, - async handler({path: ggufPath, fullTensorInfo}: InspectGgufCommand) { + async handler({path: ggufPath, fullTensorInfo, fullMetadataArrays, plainJson, outputToJsonFile}: InspectGgufCommand) { const isPathUrl = ggufPath.startsWith("http://") || ggufPath.startsWith("https://"); const resolvedGgufPath = isPathUrl ? ggufPath : path.resolve(ggufPath); - if (isPathUrl) - console.info(`${chalk.yellow("URL:")} ${resolvedGgufPath}`); - else - console.info(`${chalk.yellow("File:")} ${resolvedGgufPath}`); + if (!plainJson) { + if (isPathUrl) + console.info(`${chalk.yellow("URL:")} ${resolvedGgufPath}`); + else + console.info(`${chalk.yellow("File:")} ${resolvedGgufPath}`); + } const parsedMetadata = await readGgufFileInfo(ggufPath, {ignoreKeys: []}); const fileTypeName = getGgufFileTypeName(parsedMetadata.metadata.general?.file_type); - const metadataPrettyPrintOptions: PrettyPrintObjectOptions = { - maxArrayValues: 10, - useNumberGrouping: true, - maxArrayItemsWidth: process.stdout.columns - 1 - }; - const tensorInfoPrettyPrintOptions: PrettyPrintObjectOptions = { - maxArrayValues: fullTensorInfo - ? undefined - : 4, - useNumberGrouping: true, - maxArrayItemsWidth: process.stdout.columns - 1, - multilineObjects: false - }; - const numberLocaleFormattingOptions = { - style: "decimal", - useGrouping: true - } as const; - console.info(`${chalk.yellow("GGUF version:")} ${parsedMetadata.version}`); - console.info(`${chalk.yellow("Tensor count:")} ${parsedMetadata.tensorCount.toLocaleString("en-US", numberLocaleFormattingOptions)}`); - console.info(`${chalk.yellow("Metadata size:")} ${bytes(parsedMetadata.metadataSize)}`); - console.info(`${chalk.yellow("Tensor info size:")} ${bytes(parsedMetadata.tensorInfoSize!)}`); - console.info(`${chalk.yellow("File type:")} ${fileTypeName ?? ""} ${chalk.white(`(${parsedMetadata.metadata.general?.file_type})`)}`); - console.info(`${chalk.yellow("Metadata:")} ${prettyPrintObject(parsedMetadata.metadata, undefined, metadataPrettyPrintOptions)}`); - console.info(`${chalk.yellow("Tensor info:")} ${prettyPrintObject(parsedMetadata.tensorInfo, undefined, tensorInfoPrettyPrintOptions)}`); + if (plainJson || outputToJsonFile != null) { + const outputJson = JSON.stringify({ + version: parsedMetadata.version, + fileType: fileTypeName, + tensorCount: parsedMetadata.tensorCount, + metadataSize: parsedMetadata.metadataSize, + tensorInfoSize: parsedMetadata.tensorInfoSize, + metadata: parsedMetadata.metadata, + tensorInfo: parsedMetadata.tensorInfo + }, undefined, 4); + + if (outputToJsonFile != null) { + const filePath = path.resolve(process.cwd(), outputToJsonFile); + await fs.writeFile(filePath, outputJson, "utf8"); + console.info(`${chalk.yellow("JSON written to file:")} ${filePath}`); + } else { + console.info(outputJson); + } + } else { + const metadataPrettyPrintOptions: PrettyPrintObjectOptions = { + maxArrayValues: fullMetadataArrays + ? undefined + : 10, + useNumberGrouping: true, + maxArrayItemsWidth: process.stdout.columns - 1 + }; + const tensorInfoPrettyPrintOptions: PrettyPrintObjectOptions = { + maxArrayValues: fullTensorInfo + ? undefined + : 4, + useNumberGrouping: true, + maxArrayItemsWidth: process.stdout.columns - 1, + multilineObjects: false + }; + const numberLocaleFormattingOptions = { + style: "decimal", + useGrouping: true + } as const; + + console.info(`${chalk.yellow("GGUF version:")} ${parsedMetadata.version}`); + console.info(`${chalk.yellow("Tensor count:")} ${parsedMetadata.tensorCount.toLocaleString("en-US", numberLocaleFormattingOptions)}`); + console.info(`${chalk.yellow("Metadata size:")} ${bytes(parsedMetadata.metadataSize)}`); + console.info(`${chalk.yellow("Tensor info size:")} ${bytes(parsedMetadata.tensorInfoSize!)}`); + console.info(`${chalk.yellow("File type:")} ${fileTypeName ?? ""} ${chalk.white(`(${parsedMetadata.metadata.general?.file_type})`)}`); + console.info(`${chalk.yellow("Metadata:")} ${prettyPrintObject(parsedMetadata.metadata, undefined, metadataPrettyPrintOptions)}`); + console.info(`${chalk.yellow("Tensor info:")} ${prettyPrintObject(parsedMetadata.tensorInfo, undefined, tensorInfoPrettyPrintOptions)}`); + } } }; diff --git a/src/gguf/parser/GgufV2Parser.ts b/src/gguf/parser/GgufV2Parser.ts index d2b80e71..f6774030 100644 --- a/src/gguf/parser/GgufV2Parser.ts +++ b/src/gguf/parser/GgufV2Parser.ts @@ -10,14 +10,14 @@ import {convertMetadataKeyValueRecordToNestedObject} from "../utils/convertMetad export class GgufV2Parser { private readonly _fileReader: GgufFileReader; - private readonly _readTensorInfo: boolean; + private readonly _shouldReadTensorInfo: boolean; private readonly _ignoreKeys: string[]; private readonly _readOffset: GgufReadOffset; private readonly _logWarnings: boolean; public constructor({fileReader, readTensorInfo = true, ignoreKeys = [], readOffset, logWarnings}: GgufVersionParserOptions) { this._fileReader = fileReader; - this._readTensorInfo = readTensorInfo; + this._shouldReadTensorInfo = readTensorInfo; this._ignoreKeys = ignoreKeys; this._readOffset = readOffset; this._logWarnings = logWarnings; @@ -28,8 +28,8 @@ export class GgufV2Parser { const initialOffset = readOffset.offset; const headerReadResult = await this._readRawHeader(readOffset); - const tensorReadResult = this._readTensorInfo - ? await this._parseTensorInfo(headerReadResult.tensorCount, readOffset) + const tensorReadResult = this._shouldReadTensorInfo + ? await this._readTensorInfo(headerReadResult.tensorCount, readOffset) : null; const metadata = convertMetadataKeyValueRecordToNestedObject(headerReadResult.metadata, { logOverrideWarnings: this._logWarnings, @@ -109,7 +109,7 @@ export class GgufV2Parser { }; } - private async _parseTensorInfo(tensorCount: number | bigint, readOffset: GgufReadOffset) { + private async _readTensorInfo(tensorCount: number | bigint, readOffset: GgufReadOffset) { const initialOffset = readOffset.offset; const tensorInfo: GgufTensorInfo[] = []; diff --git a/src/gguf/parser/GgufV3Parser.ts b/src/gguf/parser/GgufV3Parser.ts new file mode 100644 index 00000000..4e4c6eb2 --- /dev/null +++ b/src/gguf/parser/GgufV3Parser.ts @@ -0,0 +1,5 @@ +import {GgufV2Parser} from "./GgufV2Parser.js"; + +export class GgufV3Parser extends GgufV2Parser { + // the implementation is the same as version 2 for now +} diff --git a/src/gguf/parser/parseGguf.ts b/src/gguf/parser/parseGguf.ts index fc225a09..473553f1 100644 --- a/src/gguf/parser/parseGguf.ts +++ b/src/gguf/parser/parseGguf.ts @@ -5,6 +5,7 @@ import {GgufReadOffset} from "../utils/GgufReadOffset.js"; import {GgufFileReader, valueTypeToBytesToRead} from "../fileReaders/GgufFileReader.js"; import {GgufFileInfo, GgufVersionParserOptions, GgufVersionParserResult} from "../types/GgufFileInfoTypes.js"; import {GgufV2Parser} from "./GgufV2Parser.js"; +import {GgufV3Parser} from "./GgufV3Parser.js"; const ggufMagic = "GGUF"; @@ -64,16 +65,18 @@ async function parseGgufUsingASpecificVersionParser( throw new UnsupportedError("GGUF version 1 is not supported by llama.cpp anymore"); case 2: - case 3: return await (new GgufV2Parser(specificVersionParserOptions)).parse(); + case 3: + return await (new GgufV3Parser(specificVersionParserOptions)).parse(); + default: if (specificVersionParserOptions.logWarnings) console.warn( getConsoleLogPrefix() + - `Unsupported GGUF version: ${specificVersionParserOptions.version}. Parsing it as version 3` + `Unsupported GGUF version "${specificVersionParserOptions.version}". Reading the file as GGUF version 3` ); - return await (new GgufV2Parser(specificVersionParserOptions)).parse(); + return await (new GgufV3Parser(specificVersionParserOptions)).parse(); } } From ddbd29ef4f923c6ae7bbcba17ffe0c3e4c662600 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Thu, 21 Mar 2024 00:38:34 +0200 Subject: [PATCH 20/52] feat: calculate model VRAM usage based on header tensor info --- llama/addon.cpp | 54 +++++ src/bindings/AddonTypes.ts | 11 +- src/bindings/Llama.ts | 27 ++- src/bindings/utils/getLlamaWithoutBackend.ts | 31 +++ src/cli/commands/ChatCommand.ts | 2 + src/evaluator/LlamaModel.ts | 18 +- src/gguf/GgufInsights.ts | 209 ++++++++++++++---- test/modelDependent/functionary/gguf.test.ts | 62 ------ .../__snapshots__/ggufParser.test.ts.snap} | 4 +- .../functionary/gguf/ggufInsights.test.ts | 81 +++++++ .../functionary/gguf/ggufParser.test.ts | 37 ++++ ...snap => ggufStandaloneParser.test.ts.snap} | 4 +- test/standalone/gguf/gguf.test.ts | 38 ---- .../gguf/ggufStandaloneParser.test.ts | 43 ++++ test/utils/modelFiles.ts | 25 ++- 15 files changed, 492 insertions(+), 154 deletions(-) create mode 100644 src/bindings/utils/getLlamaWithoutBackend.ts delete mode 100644 test/modelDependent/functionary/gguf.test.ts rename test/modelDependent/functionary/{__snapshots__/gguf.test.ts.snap => gguf/__snapshots__/ggufParser.test.ts.snap} (98%) create mode 100644 test/modelDependent/functionary/gguf/ggufInsights.test.ts create mode 100644 test/modelDependent/functionary/gguf/ggufParser.test.ts rename test/standalone/gguf/__snapshots__/{gguf.test.ts.snap => ggufStandaloneParser.test.ts.snap} (96%) delete mode 100644 test/standalone/gguf/gguf.test.ts create mode 100644 test/standalone/gguf/ggufStandaloneParser.test.ts diff --git a/llama/addon.cpp b/llama/addon.cpp index 73b0bf22..4db22869 100644 --- a/llama/addon.cpp +++ b/llama/addon.cpp @@ -515,6 +515,10 @@ class AddonModel : public Napi::ObjectWrap { return Napi::Boolean::New(info.Env(), shouldPrependBos); } + Napi::Value GetModelSize(const Napi::CallbackInfo& info) { + return Napi::Number::From(info.Env(), llama_model_size(model)); + } + static void init(Napi::Object exports) { exports.Set( "AddonModel", @@ -541,6 +545,7 @@ class AddonModel : public Napi::ObjectWrap { InstanceMethod("getTokenString", &AddonModel::GetTokenString), InstanceMethod("getTokenType", &AddonModel::GetTokenType), InstanceMethod("shouldPrependBosToken", &AddonModel::ShouldPrependBosToken), + InstanceMethod("getModelSize", &AddonModel::GetModelSize), InstanceMethod("dispose", &AddonModel::Dispose), } ) @@ -1444,6 +1449,49 @@ Napi::Value systemInfo(const Napi::CallbackInfo& info) { return Napi::String::From(info.Env(), llama_print_system_info()); } +Napi::Value addonGetSupportsGpuOffloading(const Napi::CallbackInfo& info) { + return Napi::Boolean::New(info.Env(), llama_supports_gpu_offload()); +} + +Napi::Value addonGetSupportsMmap(const Napi::CallbackInfo& info) { + return Napi::Boolean::New(info.Env(), llama_supports_mmap()); +} + +Napi::Value addonGetSupportsMlock(const Napi::CallbackInfo& info) { + return Napi::Boolean::New(info.Env(), llama_supports_mlock()); +} + +Napi::Value addonGetBlockSizeForGgmlType(const Napi::CallbackInfo& info) { + const int ggmlType = info[0].As().Int32Value(); + + if (ggmlType < 0 || ggmlType > GGML_TYPE_COUNT) { + return info.Env().Undefined(); + } + + const auto blockSize = ggml_blck_size(static_cast(ggmlType)); + + return Napi::Number::New(info.Env(), blockSize); +} + +Napi::Value addonGetTypeSizeForGgmlType(const Napi::CallbackInfo& info) { + const int ggmlType = info[0].As().Int32Value(); + + if (ggmlType < 0 || ggmlType > GGML_TYPE_COUNT) { + return info.Env().Undefined(); + } + + const auto typeSize = ggml_type_size(static_cast(ggmlType)); + + return Napi::Number::New(info.Env(), typeSize); +} + +Napi::Value addonGetConsts(const Napi::CallbackInfo& info) { + Napi::Object consts = Napi::Object::New(info.Env()); + consts.Set("ggmlMaxDims", Napi::Number::New(info.Env(), GGML_MAX_DIMS)); + + return consts; +} + int addonGetGgmlLogLevelNumber(ggml_log_level level) { switch (level) { case GGML_LOG_LEVEL_ERROR: return 2; @@ -1693,6 +1741,12 @@ static void addonFreeLlamaBackend(Napi::Env env, int* data) { Napi::Object registerCallback(Napi::Env env, Napi::Object exports) { exports.DefineProperties({ Napi::PropertyDescriptor::Function("systemInfo", systemInfo), + Napi::PropertyDescriptor::Function("getSupportsGpuOffloading", addonGetSupportsGpuOffloading), + Napi::PropertyDescriptor::Function("getSupportsMmap", addonGetSupportsMmap), + Napi::PropertyDescriptor::Function("getSupportsMlock", addonGetSupportsMlock), + Napi::PropertyDescriptor::Function("getBlockSizeForGgmlType", addonGetBlockSizeForGgmlType), + Napi::PropertyDescriptor::Function("getTypeSizeForGgmlType", addonGetTypeSizeForGgmlType), + Napi::PropertyDescriptor::Function("getConsts", addonGetConsts), Napi::PropertyDescriptor::Function("setLogger", setLogger), Napi::PropertyDescriptor::Function("setLoggerLogLevel", setLoggerLogLevel), Napi::PropertyDescriptor::Function("getGpuVramInfo", getGpuVramInfo), diff --git a/src/bindings/AddonTypes.ts b/src/bindings/AddonTypes.ts index 51c9276b..bd28f903 100644 --- a/src/bindings/AddonTypes.ts +++ b/src/bindings/AddonTypes.ts @@ -33,6 +33,14 @@ export type BindingModule = { new (grammar: AddonGrammar): AddonGrammarEvaluationState }, systemInfo(): string, + getSupportsGpuOffloading(): boolean, + getSupportsMmap(): boolean, + getSupportsMlock(): boolean, + getBlockSizeForGgmlType(ggmlType: number): number | undefined, + getTypeSizeForGgmlType(ggmlType: number): number | undefined, + getConsts(): { + ggmlMaxDims: number + }, setLogger(logger: (level: number, message: string) => void): void, setLoggerLogLevel(level: number): void, getGpuVramInfo(): { @@ -64,7 +72,8 @@ export type AddonModel = { eotToken(): Token, getTokenString(token: number): string, getTokenType(token: Token): number, - shouldPrependBosToken(): boolean + shouldPrependBosToken(): boolean, + getModelSize(): number }; export type AddonContext = { diff --git a/src/bindings/Llama.ts b/src/bindings/Llama.ts index 79726915..42063801 100644 --- a/src/bindings/Llama.ts +++ b/src/bindings/Llama.ts @@ -23,9 +23,13 @@ export class Llama { /** @internal */ public readonly _bindings: BindingModule; /** @internal */ public readonly _backendDisposeGuard = new DisposeGuard(); /** @internal */ public readonly _memoryLock = {}; + /** @internal */ public readonly _consts: ReturnType; /** @internal */ private readonly _gpu: BuildGpu; /** @internal */ private readonly _buildType: "localBuild" | "prebuilt"; /** @internal */ private readonly _cmakeOptions: Readonly>; + /** @internal */ private readonly _supportsGpuOffloading: boolean; + /** @internal */ private readonly _supportsMmap: boolean; + /** @internal */ private readonly _supportsMlock: boolean; /** @internal */ private readonly _llamaCppRelease: { readonly repo: string, readonly release: string @@ -57,6 +61,10 @@ export class Llama { }) { this._bindings = bindings; this._gpu = bindings.getGpuType() ?? false; + this._supportsGpuOffloading = bindings.getSupportsGpuOffloading(); + this._supportsMmap = bindings.getSupportsMmap(); + this._supportsMlock = bindings.getSupportsMlock(); + this._consts = bindings.getConsts(); this._logLevel = logLevel ?? LlamaLogLevel.debug; this._logger = logger; this._buildType = buildType; @@ -100,6 +108,18 @@ export class Llama { return this._gpu; } + public get supportsGpuOffloading() { + return this._supportsGpuOffloading; + } + + public get supportsMmap() { + return this._supportsMmap; + } + + public get supportsMlock() { + return this._supportsMlock; + } + public get logLevel() { return this._logLevel; } @@ -170,6 +190,11 @@ export class Llama { }); } + /** @internal */ + public async _init() { + await this._bindings.init(); + } + /** @internal */ private _onAddonLog(level: number, message: string) { const llamaLogLevel = addonLogLevelToLlamaLogLevel.get(level) ?? LlamaLogLevel.fatal; @@ -280,7 +305,7 @@ export class Llama { }); if (!skipLlamaInit) - await llama._bindings.init(); + await llama._init(); return llama; } diff --git a/src/bindings/utils/getLlamaWithoutBackend.ts b/src/bindings/utils/getLlamaWithoutBackend.ts new file mode 100644 index 00000000..f52c9589 --- /dev/null +++ b/src/bindings/utils/getLlamaWithoutBackend.ts @@ -0,0 +1,31 @@ +import {withLock} from "lifecycle-utils"; +import {getLlamaForOptions} from "../getLlama.js"; +import {LlamaLogLevel} from "../types.js"; +import {Llama} from "../Llama.js"; + +let sharedLlamaWithoutBackend: Llama | null = null; + +/** + * This is used to access various methods in the addon side without actually using a backend + */ +export async function getLlamaWithoutBackend() { + if (sharedLlamaWithoutBackend != null) + return sharedLlamaWithoutBackend; + + return await withLock(getLlamaWithoutBackend, "loadAddon", async () => { + if (sharedLlamaWithoutBackend != null) + return sharedLlamaWithoutBackend; + + sharedLlamaWithoutBackend = await getLlamaForOptions({ + gpu: false, + progressLogs: false, + logLevel: LlamaLogLevel.error, + build: "never", + usePrebuiltBinaries: true + }, { + skipLlamaInit: true + }); + + return sharedLlamaWithoutBackend; + }); +} diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts index cbcdaccc..e46fdf70 100644 --- a/src/cli/commands/ChatCommand.ts +++ b/src/cli/commands/ChatCommand.ts @@ -4,6 +4,7 @@ import path from "path"; import {CommandModule} from "yargs"; import chalk from "chalk"; import fs from "fs-extra"; +import bytes from "bytes"; import {chatCommandHistoryFilePath, defaultChatSystemPrompt} from "../../config.js"; import {getIsInDocumentationMode} from "../../state.js"; import {ReplHistory} from "../../utils/ReplHistory.js"; @@ -373,6 +374,7 @@ async function RunChat({ console.info(`${chalk.yellow("Train context size:")} ${model.trainContextSize}`); console.info(`${chalk.yellow("Model type:")} ${model.typeDescription}`); + console.info(`${chalk.yellow("Model size:")} ${bytes(model.size)}`); console.info(`${chalk.yellow("BOS:")} ${bos}`); console.info(`${chalk.yellow("EOS:")} ${eos}`); console.info(`${chalk.yellow("Chat wrapper:")} ${chatWrapper.wrapperName}`); diff --git a/src/evaluator/LlamaModel.ts b/src/evaluator/LlamaModel.ts index c5468fee..b6cd8597 100644 --- a/src/evaluator/LlamaModel.ts +++ b/src/evaluator/LlamaModel.ts @@ -23,7 +23,10 @@ export type LlamaModelOptions = { /** only load the vocabulary, no weights */ vocabOnly?: boolean, - /** use mmap if possible */ + /** + * Use mmap if possible. + * Enabled by default in llama.cpp. + */ useMmap?: boolean, /** @@ -72,7 +75,9 @@ export class LlamaModel { gpuLayers, vocabOnly, useMmap, - useMlock, + useMlock: _llama.supportsMlock + ? useMlock + : undefined, onLoadProgress: onLoadProgress == null ? undefined : (loadPercentage: number) => { @@ -134,6 +139,15 @@ export class LlamaModel { return this._filename; } + /** + * Total model size in memory in bytes + */ + public get size() { + this._ensureNotDisposed(); + + return this._model.getModelSize(); + } + /** * Transform text into tokens that can be fed to the model * @param text - the text to tokenize diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts index 29e09b2e..7085413e 100644 --- a/src/gguf/GgufInsights.ts +++ b/src/gguf/GgufInsights.ts @@ -1,51 +1,186 @@ +import {Llama} from "../bindings/Llama.js"; +import {getLlamaWithoutBackend} from "../bindings/utils/getLlamaWithoutBackend.js"; import {getGgufMetadataLlmData} from "./utils/getGgufMetadataLlmData.js"; -import {GgufMetadata} from "./types/GgufMetadataTypes.js"; +import {GgufFileInfo} from "./types/GgufFileInfoTypes.js"; +import {GgufTensorInfo} from "./types/GgufTensorInfoTypes.js"; export class GgufInsights { - public readonly metadata: GgufMetadata; - public readonly metadataSize: number; + /** @internal */ private readonly _llama: Llama; + /** @internal */ private readonly _modelSize: number; + public readonly ggufFileInfo: GgufFileInfo; - public constructor({ - metadata, - metadataSize - }: { - metadata: GgufMetadata, - metadataSize: number - }) { - this.metadata = metadata; - this.metadataSize = metadataSize; + private constructor(ggufFileInfo: GgufFileInfo, llama: Llama) { + this._llama = llama; + this.ggufFileInfo = ggufFileInfo; + + this._modelSize = calculateTensorsSize(ggufFileInfo.tensorInfo ?? [], llama); } - /** - * fp16 k,v matrices - */ - public get kvMatrices() { - // 2 bytes each * 2 key and value - const llmData = getGgufMetadataLlmData(this.metadata); - return ( - 2 * 2 * - (llmData.context_length ?? 1) * - (llmData.block_count ?? 1) * - (llmData.embedding_length ?? 1) * - (llmData.attention?.head_count_kv ?? 1) / - (llmData.attention?.head_count ?? 1) - ); + public get totalLayers() { + const llmData = getGgufMetadataLlmData(this.ggufFileInfo.metadata); + + return (llmData.block_count ?? this._determineNumberOfLayersFromTensorInfo()) + 1; + } + + public get modelSize() { + return this._modelSize; + } + + public calculateModelResourceRequirements(gpuLayers: number) { + const {cpu, gpu} = this.getTensorLoadSplit(gpuLayers); + + return { + cpuRam: calculateTensorsSize(cpu, this._llama), + gpuVram: calculateTensorsSize(gpu, this._llama) + }; + } + + public getTensorLoadSplit(gpuLayers: number): { + cpu: GgufTensorInfo[], + gpu: GgufTensorInfo[] + } { + const tensorInfo = this.ggufFileInfo.tensorInfo ?? []; + + if (gpuLayers === 0) { + return { + cpu: tensorInfo, + gpu: [] + }; + } + + const gpuTensors: GgufTensorInfo[] = []; + const cpuTensors: GgufTensorInfo[] = []; + + for (const singleTensorInfo of tensorInfo) { + const {layerNumber} = parseTensorName(singleTensorInfo.name); + + if (layerNumber == null || layerNumber < gpuLayers) + gpuTensors.push(singleTensorInfo); + else + cpuTensors.push(singleTensorInfo); + } + + return { + cpu: cpuTensors, + gpu: gpuTensors + }; + } + + /** @internal */ + public _determineNumberOfLayersFromTensorInfo(): number { + const layerNumbers = new Set(); + + for (const singleTensorInfo of (this.ggufFileInfo.tensorInfo ?? [])) { + const {layerNumber} = parseTensorName(singleTensorInfo.name); + + if (layerNumber != null) + layerNumbers.add(layerNumber); + } + + return layerNumbers.size; } /** - * This amount is the overhead + tensors in memory + * @param ggufFileInfo + * @param llama - If you already have a `Llama` instance, pass it to reuse it for the `GgufInsights` instance. + * If you don't pass a `Llama` instance, a basic `Llama` instance is created as a fallback - it's a slim instance that + * doesn't instantiate a `llama.cpp` backend, so it won't utilize the GPU at all, and be shared with other `GgufInsights` instances + * that need a fallback `Llama` instance. */ - public get graphSize() { - // TODO: get this from the llama.cpp's graph calculations instead of - // estimating it's 1/6 * kv_cache_size * num_gqa - const llmData = getGgufMetadataLlmData(this.metadata); - return ( - (llmData.attention?.head_count_kv ?? 1) / - (llmData.attention?.head_count ?? 1) - ) * this.kvMatrices / 6; + public static async from(ggufFileInfo: GgufFileInfo, llama?: Llama) { + let resolvedLlama = llama; + if (resolvedLlama == null) + resolvedLlama = await getLlamaWithoutBackend(); + + return new GgufInsights(ggufFileInfo, resolvedLlama); } +} + +function parseTensorName(tensorName?: string): { + layerNumber: number | undefined +} { + if (tensorName == null) + return {layerNumber: undefined}; + + const layerTensorPrefix = "blk."; + if (!tensorName.startsWith(layerTensorPrefix)) + return {layerNumber: undefined}; + + const dotIndex = tensorName.indexOf(".", layerTensorPrefix.length); + const layerNumberString = tensorName.slice( + layerTensorPrefix.length, + dotIndex < 0 + ? tensorName.length + : dotIndex + ); + + const layerNumber = parseInt(layerNumberString); + if (Number.isFinite(layerNumber)) + return {layerNumber}; + + return {layerNumber: undefined}; +} + +function calculateTensorsSize(tensorsInfo: GgufTensorInfo[], llama: Llama) { + let size = 0; + for (const tensorInfo of tensorsInfo) + size += calculateTensorSize(tensorInfo, llama); + + return size; +} - public get VRAMUsage() { - return this.graphSize + this.kvMatrices + this.metadataSize; +function calculateTensorSize(tensor: GgufTensorInfo, llama: Llama) { + const typeSize = llama._bindings.getTypeSizeForGgmlType(tensor.ggmlType); + const blockSize = llama._bindings.getBlockSizeForGgmlType(tensor.ggmlType); + const ggmlMaxDims = llama._consts.ggmlMaxDims; + + if (typeSize == null || blockSize == null) + throw new Error("Invalid type or block size"); + + const {ne, nb} = getTensorNeAndNb(tensor, {typeSize, blockSize, ggmlMaxDims}); + + if (blockSize === 1) { + let totalBytes = typeSize; + for (let i = 0; i < ggmlMaxDims; i++) { + totalBytes += (ne[i] - 1) * nb[i]; + } + + return totalBytes; + } else { + let totalBytes = Math.floor((ne[0] * nb[0]) / blockSize); + for (let i = 1; i < ggmlMaxDims; i++) { + totalBytes += (ne[i] - 1) * nb[i]; + } + + return totalBytes; } } + +function getTensorNeAndNb(tensor: GgufTensorInfo, { + typeSize, blockSize, ggmlMaxDims +}: { + typeSize: number, blockSize: number, ggmlMaxDims: number +}) { + // number of elements + // source: `ggml_new_tensor_impl` in `ggml.c` + const ne = [ + ...tensor.dimensions, + ...(Array(Math.max(0, ggmlMaxDims - tensor.dimensions.length)).fill(1)) + ].slice(0, ggmlMaxDims); + + // number of bytes + // source: `ggml_new_tensor_impl` in `ggml.c` + const nb = [ + typeSize, + Math.floor(typeSize * (ne[0] / blockSize)), + ...Array(ggmlMaxDims - 2).fill(0) + ]; + for (let i = 2; i < ggmlMaxDims; i++) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + + return { + ne, + nb + }; +} diff --git a/test/modelDependent/functionary/gguf.test.ts b/test/modelDependent/functionary/gguf.test.ts deleted file mode 100644 index df7832e4..00000000 --- a/test/modelDependent/functionary/gguf.test.ts +++ /dev/null @@ -1,62 +0,0 @@ -import {describe, expect, it, test} from "vitest"; -import {GgufFsFileReader} from "../../../src/gguf/fileReaders/GgufFsFileReader.js"; -import {parseGguf} from "../../../src/gguf/parser/parseGguf.js"; -import {getModelFile} from "../../utils/modelFiles.js"; -import {GgufInsights} from "../../../src/gguf/GgufInsights.js"; -import {getTestLlama} from "../../utils/getTestLlama.js"; -import {readGgufFileInfo} from "../../../src/gguf/readGgufFileInfo.js"; -import {simplifyGgufInfoForTestSnapshot} from "../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; - -describe("GGUF Parser", async () => { - const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); - - test("Magic should be GGUF local model", async () => { - const fileReader = new GgufFsFileReader({filePath: modelPath}); - const magic = await fileReader.readByteRange(0, 4); - const magicText = String.fromCharCode(...magic); - - expect(magicText).toBe("GGUF"); - }); - - it("should parse local gguf model", async () => { - const fileReader = new GgufFsFileReader({filePath: modelPath}); - - const metadata = await parseGguf({ - fileReader: fileReader - }); - expect(metadata.tensorInfo!.length).to.be.eql(Number(metadata.tensorCount)); - - expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); - }); - - it("should calculate GGUF VRAM Usage", async () => { - const fileReader = new GgufFsFileReader({filePath: modelPath}); - - const metadata = await parseGguf({ - fileReader: fileReader - }); - - const ggufInsights = new GgufInsights(metadata); - - const llama = await getTestLlama(); - const model = await llama.loadModel({ - modelPath: modelPath - }); - - const usedRam = llama.getVramState().used; - - expect(ggufInsights.VRAMUsage).toMatchInlineSnapshot("4474643028.666667"); - expect(usedRam).to.be.gte(3.5 * Math.pow(1024, 3)); - expect(usedRam).to.be.lte(4.5 * Math.pow(1024, 3)); - await model.dispose(); - }); - - it("should fetch GGUF metadata", async () => { - const ggufMetadataParseResult = await readGgufFileInfo(modelPath); - - expect(simplifyGgufInfoForTestSnapshot(ggufMetadataParseResult)).toMatchSnapshot(); - - const insights = new GgufInsights(ggufMetadataParseResult); - expect(insights.VRAMUsage).toMatchInlineSnapshot("4474643028.666667"); - }); -}); diff --git a/test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap b/test/modelDependent/functionary/gguf/__snapshots__/ggufParser.test.ts.snap similarity index 98% rename from test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap rename to test/modelDependent/functionary/gguf/__snapshots__/ggufParser.test.ts.snap index ffd58170..3d356afc 100644 --- a/test/modelDependent/functionary/__snapshots__/gguf.test.ts.snap +++ b/test/modelDependent/functionary/gguf/__snapshots__/ggufParser.test.ts.snap @@ -1,6 +1,6 @@ // Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html -exports[`GGUF Parser > should fetch GGUF metadata 1`] = ` +exports[`gguf > parser > should fetch GGUF metadata 1`] = ` { "metadata": { "general": { @@ -146,7 +146,7 @@ exports[`GGUF Parser > should fetch GGUF metadata 1`] = ` } `; -exports[`GGUF Parser > should parse local gguf model 1`] = ` +exports[`gguf > parser > should parse local gguf model 1`] = ` { "metadata": { "general": { diff --git a/test/modelDependent/functionary/gguf/ggufInsights.test.ts b/test/modelDependent/functionary/gguf/ggufInsights.test.ts new file mode 100644 index 00000000..704dd654 --- /dev/null +++ b/test/modelDependent/functionary/gguf/ggufInsights.test.ts @@ -0,0 +1,81 @@ +import {describe, expect, test} from "vitest"; +import {getModelFile} from "../../../utils/modelFiles.js"; +import {GgufInsights} from "../../../../src/gguf/GgufInsights.js"; +import {getTestLlama} from "../../../utils/getTestLlama.js"; +import {readGgufFileInfo} from "../../../../src/gguf/readGgufFileInfo.js"; +import {getGgufMetadataLlmData} from "../../../../src/gguf/utils/getGgufMetadataLlmData.js"; + +describe("gguf", async () => { + describe("insights", async () => { + const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); + + test("determine the number of layers from the tensor info", async () => { + const llama = await getTestLlama(); + const ggufMetadataParseResult = await readGgufFileInfo(modelPath); + const insights = await GgufInsights.from(ggufMetadataParseResult, llama); + const llmData = getGgufMetadataLlmData(ggufMetadataParseResult.metadata); + + expect(insights._determineNumberOfLayersFromTensorInfo()).to.be.eql(llmData.block_count); + }); + + test("calculated model size stays the same", async () => { + const llama = await getTestLlama(); + const ggufMetadataParseResult = await readGgufFileInfo(modelPath); + + const ggufInsights = await GgufInsights.from(ggufMetadataParseResult, llama); + expect(ggufInsights.modelSize).toMatchInlineSnapshot("4108204160"); + }); + + test("predicted VRAM usage should match actual VRAM usage", async () => { + const llama = await getTestLlama(); + const ggufMetadataParseResult = await readGgufFileInfo(modelPath); + + const ggufInsights = await GgufInsights.from(ggufMetadataParseResult, llama); + + const initialVramUsage = llama.getVramState().used; + const model = await llama.loadModel({ + modelPath: modelPath + }); + const currentVramUsage = llama.getVramState().used; + + const vramUsageDiff = currentVramUsage - initialVramUsage; + + const s100MB = 100 * Math.pow(1024, 2); + const s5MB = 5 * Math.pow(1024, 2); + + expect(ggufInsights.modelSize).toMatchInlineSnapshot("4108204160"); + expect(Math.abs(vramUsageDiff - ggufInsights.modelSize)).to.be.lte(s100MB); + + const calculationDiffWithActual = ggufInsights.modelSize - model.size; + expect(Math.abs(calculationDiffWithActual)).to.be.lte(s5MB); // tolerate such a small difference + + if (calculationDiffWithActual !== 0) + console.warn("Model size calculation is off by", calculationDiffWithActual, "bytes"); + + await model.dispose(); + }); + + test("predicted VRAM usage should match actual VRAM usage when using gpuLayers", async () => { + const llama = await getTestLlama(); + const ggufMetadataParseResult = await readGgufFileInfo(modelPath); + + const ggufInsights = await GgufInsights.from(ggufMetadataParseResult, llama); + + const initialVramUsage = llama.getVramState().used; + const model = await llama.loadModel({ + modelPath: modelPath, + gpuLayers: 16 + }); + const currentVramUsage = llama.getVramState().used; + + const vramUsageDiff = currentVramUsage - initialVramUsage; + + const s100MB = 100 * Math.pow(1024, 2); + const calculatedVramUsage = ggufInsights.calculateModelResourceRequirements(16).gpuVram; + + expect(Math.abs(vramUsageDiff - calculatedVramUsage)).to.be.lte(s100MB); + + await model.dispose(); + }); + }); +}); diff --git a/test/modelDependent/functionary/gguf/ggufParser.test.ts b/test/modelDependent/functionary/gguf/ggufParser.test.ts new file mode 100644 index 00000000..42fb7dbf --- /dev/null +++ b/test/modelDependent/functionary/gguf/ggufParser.test.ts @@ -0,0 +1,37 @@ +import {describe, expect, it, test} from "vitest"; +import {GgufFsFileReader} from "../../../../src/gguf/fileReaders/GgufFsFileReader.js"; +import {parseGguf} from "../../../../src/gguf/parser/parseGguf.js"; +import {getModelFile} from "../../../utils/modelFiles.js"; +import {readGgufFileInfo} from "../../../../src/gguf/readGgufFileInfo.js"; +import {simplifyGgufInfoForTestSnapshot} from "../../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; + +describe("gguf", async () => { + describe("parser", async () => { + const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); + + test("Magic should be GGUF local model", async () => { + const fileReader = new GgufFsFileReader({filePath: modelPath}); + const magic = await fileReader.readByteRange(0, 4); + const magicText = String.fromCharCode(...magic); + + expect(magicText).toBe("GGUF"); + }); + + it("should parse local gguf model", async () => { + const fileReader = new GgufFsFileReader({filePath: modelPath}); + + const metadata = await parseGguf({ + fileReader: fileReader + }); + expect(metadata.tensorInfo!.length).to.be.eql(Number(metadata.tensorCount)); + + expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); + }); + + it("should fetch GGUF metadata", async () => { + const ggufMetadataParseResult = await readGgufFileInfo(modelPath); + + expect(simplifyGgufInfoForTestSnapshot(ggufMetadataParseResult)).toMatchSnapshot(); + }); + }); +}); diff --git a/test/standalone/gguf/__snapshots__/gguf.test.ts.snap b/test/standalone/gguf/__snapshots__/ggufStandaloneParser.test.ts.snap similarity index 96% rename from test/standalone/gguf/__snapshots__/gguf.test.ts.snap rename to test/standalone/gguf/__snapshots__/ggufStandaloneParser.test.ts.snap index fb6b6e60..0c07ab18 100644 --- a/test/standalone/gguf/__snapshots__/gguf.test.ts.snap +++ b/test/standalone/gguf/__snapshots__/ggufStandaloneParser.test.ts.snap @@ -1,6 +1,6 @@ // Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html -exports[`GGUF Parser > should parse remote gguf model 1`] = ` +exports[`gguf > parser > should parse remote gguf model 1`] = ` { "metadata": { "falcon": { @@ -121,7 +121,7 @@ exports[`GGUF Parser > should parse remote gguf model 1`] = ` } `; -exports[`GGUF Parser > should parse remote gguf model without tensor info 1`] = ` +exports[`gguf > parser > should parse remote gguf model without tensor info 1`] = ` { "metadata": { "falcon": { diff --git a/test/standalone/gguf/gguf.test.ts b/test/standalone/gguf/gguf.test.ts deleted file mode 100644 index b6ce2622..00000000 --- a/test/standalone/gguf/gguf.test.ts +++ /dev/null @@ -1,38 +0,0 @@ -import {describe, expect, it, test} from "vitest"; -import {parseGguf} from "../../../src/gguf/parser/parseGguf.js"; -import {GgufNetworkFetchFileReader} from "../../../src/gguf/fileReaders/GgufNetworkFetchFileReader.js"; -import {simplifyGgufInfoForTestSnapshot} from "../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; - -const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; - -describe("GGUF Parser", async () => { - test("Magic should be GGUF remote model", {timeout: 1000 * 60 * 10}, async () => { - const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); - - const magic = await fileReader.readByteRange(0, 4); - const magicText = String.fromCharCode(...magic); - - expect(magicText).toBe("GGUF"); - }); - - it("should parse remote gguf model", async () => { - const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); - - const metadata = await parseGguf({ - fileReader: fileReader - }); - - expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); - }); - - it("should parse remote gguf model without tensor info", async () => { - const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); - - const metadata = await parseGguf({ - fileReader: fileReader, - readTensorInfo: false - }); - - expect(simplifyGgufInfoForTestSnapshot(metadata)).toMatchSnapshot(); - }); -}); diff --git a/test/standalone/gguf/ggufStandaloneParser.test.ts b/test/standalone/gguf/ggufStandaloneParser.test.ts new file mode 100644 index 00000000..f90d2e0e --- /dev/null +++ b/test/standalone/gguf/ggufStandaloneParser.test.ts @@ -0,0 +1,43 @@ +import {describe, expect, it, test} from "vitest"; +import {parseGguf} from "../../../src/gguf/parser/parseGguf.js"; +import {GgufNetworkFetchFileReader} from "../../../src/gguf/fileReaders/GgufNetworkFetchFileReader.js"; +import {simplifyGgufInfoForTestSnapshot} from "../../utils/helpers/simplifyGgufInfoForTestSnapshot.js"; + +const remoteGGUFModel = "https://huggingface.co/TheBloke/Falcon-180B-Chat-GGUF/resolve/main/falcon-180b-chat.Q6_K.gguf-split-a?download=true"; + +describe("gguf", async () => { + describe("parser", async () => { + test("Magic should be GGUF remote model", {timeout: 1000 * 60 * 10}, async () => { + const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); + + const magic = await fileReader.readByteRange(0, 4); + const magicText = String.fromCharCode(...magic); + + expect(magicText) + .toBe("GGUF"); + }); + + it("should parse remote gguf model", async () => { + const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); + + const metadata = await parseGguf({ + fileReader: fileReader + }); + + expect(simplifyGgufInfoForTestSnapshot(metadata)) + .toMatchSnapshot(); + }); + + it("should parse remote gguf model without tensor info", async () => { + const fileReader = new GgufNetworkFetchFileReader({url: remoteGGUFModel}); + + const metadata = await parseGguf({ + fileReader: fileReader, + readTensorInfo: false + }); + + expect(simplifyGgufInfoForTestSnapshot(metadata)) + .toMatchSnapshot(); + }); + }); +}); diff --git a/test/utils/modelFiles.ts b/test/utils/modelFiles.ts index aebdfe37..98479913 100644 --- a/test/utils/modelFiles.ts +++ b/test/utils/modelFiles.ts @@ -4,6 +4,7 @@ import {downloadFile, downloadSequence} from "ipull"; import fs from "fs-extra"; import chalk from "chalk"; import withStatusLogs from "../../src/utils/withStatusLogs.js"; +import {withLockfile} from "../../src/utils/withLockfile.js"; const __dirname = path.dirname(fileURLToPath(import.meta.url)); @@ -22,22 +23,28 @@ export async function getModelFile(modelName: keyof typeof supportedModels) { if (await fs.pathExists(modelFilePath)) return modelFilePath; + await fs.ensureDir(modelsFolder); + return await withStatusLogs({ loading: chalk.blue(`Downloading model "${modelName}"`), success: chalk.blue(`Downloaded model "${modelName}"`), fail: chalk.blue(`Failed to download model "${modelName}"`) }, async () => { - const modelUrl = supportedModels[modelName]; + return await withLockfile({ + resourcePath: modelFilePath + }, async () => { + const modelUrl = supportedModels[modelName]; - const downloader = await downloadFile({ - url: modelUrl, - directory: path.dirname(modelFilePath), - fileName: path.basename(modelFilePath), - cliProgress: true - }); - await downloader.download(); + const downloader = await downloadFile({ + url: modelUrl, + directory: path.dirname(modelFilePath), + fileName: path.basename(modelFilePath), + cliProgress: true + }); + await downloader.download(); - return modelFilePath; + return modelFilePath; + }); }); } From 6b9d2b9a2657a780f40cb018d8b53a7e68255fbe Mon Sep 17 00:00:00 2001 From: Gilad S Date: Thu, 21 Mar 2024 21:30:43 +0200 Subject: [PATCH 21/52] test: skip VRAM tests when running on a machine without a GPU --- .../functionary/gguf/ggufInsights.test.ts | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/test/modelDependent/functionary/gguf/ggufInsights.test.ts b/test/modelDependent/functionary/gguf/ggufInsights.test.ts index 704dd654..17d85f50 100644 --- a/test/modelDependent/functionary/gguf/ggufInsights.test.ts +++ b/test/modelDependent/functionary/gguf/ggufInsights.test.ts @@ -26,10 +26,13 @@ describe("gguf", async () => { expect(ggufInsights.modelSize).toMatchInlineSnapshot("4108204160"); }); - test("predicted VRAM usage should match actual VRAM usage", async () => { + test("predicted VRAM usage should match actual VRAM usage", async (context) => { const llama = await getTestLlama(); const ggufMetadataParseResult = await readGgufFileInfo(modelPath); + if (llama.gpu === false) + return context.skip(); + const ggufInsights = await GgufInsights.from(ggufMetadataParseResult, llama); const initialVramUsage = llama.getVramState().used; @@ -55,10 +58,13 @@ describe("gguf", async () => { await model.dispose(); }); - test("predicted VRAM usage should match actual VRAM usage when using gpuLayers", async () => { + test("predicted VRAM usage should match actual VRAM usage when using gpuLayers", async (context) => { const llama = await getTestLlama(); const ggufMetadataParseResult = await readGgufFileInfo(modelPath); + if (llama.gpu === false) + return context.skip(); + const ggufInsights = await GgufInsights.from(ggufMetadataParseResult, llama); const initialVramUsage = llama.getVramState().used; From cad3fd29941cb69522d49a5bfbfa8e3808eba6f6 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Sat, 23 Mar 2024 23:59:14 +0200 Subject: [PATCH 22/52] feat: calculate context VRAM usage based on header tensor info --- .gitignore | 2 + llama/addon.cpp | 23 +- src/bindings/AddonTypes.ts | 9 +- src/evaluator/LlamaContext/LlamaContext.ts | 11 + src/gguf/GgufInsights.ts | 179 ++++++++- src/gguf/parser/parseGguf.ts | 3 + src/gguf/types/GgufFileInfoTypes.ts | 6 +- ....ts => getGgufMetadataArchitectureData.ts} | 2 +- .../functionary/gguf/ggufInsights.test.ts | 341 +++++++++++++++++- 9 files changed, 548 insertions(+), 28 deletions(-) rename src/gguf/utils/{getGgufMetadataLlmData.ts => getGgufMetadataArchitectureData.ts} (75%) diff --git a/.gitignore b/.gitignore index 9a82b9c8..d6a148e0 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,8 @@ node_modules /llama/lastBuild.json /llama/gitRelease.bundle /llama/.temp +/llama/.idea +/llama/cmake-build-debug /llama/localBuilds /llama/Release /llama/Debug diff --git a/llama/addon.cpp b/llama/addon.cpp index 4db22869..060635c4 100644 --- a/llama/addon.cpp +++ b/llama/addon.cpp @@ -827,6 +827,10 @@ class AddonContext : public Napi::ObjectWrap { context_params.n_ubatch = context_params.n_batch; // the batch queue is managed in the JS side, so there's no need for managing it on the C++ side } + if (options.Has("sequences")) { + context_params.n_seq_max = options.Get("sequences").As().Uint32Value(); + } + if (options.Has("embeddings")) { context_params.embeddings = options.Get("embeddings").As().Value(); } @@ -1044,6 +1048,15 @@ class AddonContext : public Napi::ObjectWrap { return result; } + Napi::Value GetStateSize(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + return Napi::Number::From(info.Env(), llama_get_state_size(ctx)); + } + Napi::Value PrintTimings(const Napi::CallbackInfo& info) { llama_print_timings(ctx); llama_reset_timings(ctx); @@ -1068,6 +1081,7 @@ class AddonContext : public Napi::ObjectWrap { InstanceMethod("sampleToken", &AddonContext::SampleToken), InstanceMethod("acceptGrammarEvaluationStateToken", &AddonContext::AcceptGrammarEvaluationStateToken), InstanceMethod("getEmbedding", &AddonContext::GetEmbedding), + InstanceMethod("getStateSize", &AddonContext::GetStateSize), InstanceMethod("printTimings", &AddonContext::PrintTimings), InstanceMethod("dispose", &AddonContext::Dispose), } @@ -1469,7 +1483,7 @@ Napi::Value addonGetBlockSizeForGgmlType(const Napi::CallbackInfo& info) { } const auto blockSize = ggml_blck_size(static_cast(ggmlType)); - + return Napi::Number::New(info.Env(), blockSize); } @@ -1481,13 +1495,18 @@ Napi::Value addonGetTypeSizeForGgmlType(const Napi::CallbackInfo& info) { } const auto typeSize = ggml_type_size(static_cast(ggmlType)); - + return Napi::Number::New(info.Env(), typeSize); } Napi::Value addonGetConsts(const Napi::CallbackInfo& info) { Napi::Object consts = Napi::Object::New(info.Env()); consts.Set("ggmlMaxDims", Napi::Number::New(info.Env(), GGML_MAX_DIMS)); + consts.Set("ggmlTypeF16Size", Napi::Number::New(info.Env(), ggml_type_size(GGML_TYPE_F16))); + consts.Set("ggmlTypeF32Size", Napi::Number::New(info.Env(), ggml_type_size(GGML_TYPE_F32))); + consts.Set("llamaMaxRngState", Napi::Number::New(info.Env(), LLAMA_MAX_RNG_STATE)); + consts.Set("llamaPosSize", Napi::Number::New(info.Env(), sizeof(llama_pos))); + consts.Set("llamaSeqIdSize", Napi::Number::New(info.Env(), sizeof(llama_seq_id))); return consts; } diff --git a/src/bindings/AddonTypes.ts b/src/bindings/AddonTypes.ts index bd28f903..7cff953c 100644 --- a/src/bindings/AddonTypes.ts +++ b/src/bindings/AddonTypes.ts @@ -18,6 +18,7 @@ export type BindingModule = { seed?: number, contextSize?: number, batchSize?: number, + sequences?: number, logitsAll?: boolean, embeddings?: boolean, threads?: number @@ -39,7 +40,12 @@ export type BindingModule = { getBlockSizeForGgmlType(ggmlType: number): number | undefined, getTypeSizeForGgmlType(ggmlType: number): number | undefined, getConsts(): { - ggmlMaxDims: number + ggmlMaxDims: number, + ggmlTypeF16Size: number, + ggmlTypeF32Size: number, + llamaMaxRngState: number, + llamaPosSize: number, + llamaSeqIdSize: number }, setLogger(logger: (level: number, message: string) => void): void, setLoggerLogLevel(level: number): void, @@ -109,6 +115,7 @@ export type AddonContext = { acceptGrammarEvaluationStateToken(grammarEvaluationState: AddonGrammarEvaluationState, token: Token): void, getEmbedding(inputTokensLength: number): Float64Array, + getStateSize(): number, printTimings(): void }; diff --git a/src/evaluator/LlamaContext/LlamaContext.ts b/src/evaluator/LlamaContext/LlamaContext.ts index 3126cca3..9a05a32d 100644 --- a/src/evaluator/LlamaContext/LlamaContext.ts +++ b/src/evaluator/LlamaContext/LlamaContext.ts @@ -70,6 +70,7 @@ export class LlamaContext { seed: seed != null ? Math.max(-1, Math.floor(seed)) : undefined, contextSize: this._contextSize * this._totalSequences, // each sequence needs its own of cells batchSize: this._batchSize, + sequences: this._totalSequences, threads: Math.max(0, Math.floor(threads)), embeddings: _embeddings, noSeed: _noSeed @@ -129,6 +130,16 @@ export class LlamaContext { return this._batchSize; } + /** + * The actual size of the state in the memory in bytes. + * This value is provided by `llama.cpp` and doesn't include all the memory overhead of the context. + */ + public get stateSize() { + this._ensureNotDisposed(); + + return this._ctx.getStateSize(); + } + public getAllocatedContextSize(): number { this._ensureNotDisposed(); diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts index 7085413e..05fb55cc 100644 --- a/src/gguf/GgufInsights.ts +++ b/src/gguf/GgufInsights.ts @@ -1,12 +1,18 @@ import {Llama} from "../bindings/Llama.js"; import {getLlamaWithoutBackend} from "../bindings/utils/getLlamaWithoutBackend.js"; -import {getGgufMetadataLlmData} from "./utils/getGgufMetadataLlmData.js"; import {GgufFileInfo} from "./types/GgufFileInfoTypes.js"; import {GgufTensorInfo} from "./types/GgufTensorInfoTypes.js"; +import {GgufArchitectureType} from "./types/GgufMetadataTypes.js"; + +export type GgufInsightsResourceRequirements = { + cpuRam: number, + gpuVram: number +}; export class GgufInsights { /** @internal */ private readonly _llama: Llama; /** @internal */ private readonly _modelSize: number; + /** @internal */ private _totalLayers: number | null = null; public readonly ggufFileInfo: GgufFileInfo; private constructor(ggufFileInfo: GgufFileInfo, llama: Llama) { @@ -17,17 +23,21 @@ export class GgufInsights { } public get totalLayers() { - const llmData = getGgufMetadataLlmData(this.ggufFileInfo.metadata); + if (this._totalLayers != null) + return this._totalLayers; + + const outputLayers = 1; + this._totalLayers = this._getFileLayers() + outputLayers; - return (llmData.block_count ?? this._determineNumberOfLayersFromTensorInfo()) + 1; + return this._totalLayers; } public get modelSize() { return this._modelSize; } - public calculateModelResourceRequirements(gpuLayers: number) { - const {cpu, gpu} = this.getTensorLoadSplit(gpuLayers); + public estimateModelResourceRequirements({gpuLayers}: {gpuLayers: number}): GgufInsightsResourceRequirements { + const {cpu, gpu} = this._getTensorResourceSplit(gpuLayers); return { cpuRam: calculateTensorsSize(cpu, this._llama), @@ -35,7 +45,120 @@ export class GgufInsights { }; } - public getTensorLoadSplit(gpuLayers: number): { + /** + * Estimates the memory required to create a context of the given parameters based on the implementation details of `llama.cpp`. + * The calculation doesn't include a precise estimation of the graph overhead memory, so it uses a rough estimate for that. + * The estimation for the graph overhead memory will be improved in the future to be more precise, but it's good enough for now. + */ + public estimateContextResourceRequirements({ + contextSize, batchSize, modelGpuLayers, sequences, isEmbeddingContext = false, includeGraphOverhead = true + }: { + contextSize: number, batchSize: number, modelGpuLayers: number, sequences: number, isEmbeddingContext?: boolean, + includeGraphOverhead?: boolean + }): GgufInsightsResourceRequirements { + const totalLayers = this.totalLayers; + const finalGpuLayers = Math.max(0, Math.min(modelGpuLayers ?? totalLayers, totalLayers)); + const finalCpuLayers = totalLayers - finalGpuLayers; + const llmData = this.ggufFileInfo.architectureMetadata; + + const vocabularySize = llmData.vocab_size ?? this.ggufFileInfo.metadata.tokenizer.ggml.tokens.length; + const logitsSize = vocabularySize * batchSize; + const embedSize = isEmbeddingContext + ? (llmData.embedding_length ?? 0) * batchSize + : 0; + + const sizeTBytes = 8; // sizeof(size_t) + const floatBytes = 4; // sizeof(float) + const uint32TBytes = 4; // sizeof(uint32_t) + + // source: `llama_get_state_size` in `llama.cpp` + const sRngSize = sizeTBytes; + const sRng = this._llama._consts.llamaMaxRngState; + const sLogitsSize = sizeTBytes; + const sLogits = logitsSize * floatBytes; + const sEmbeddingSize = sizeTBytes; + const sEmbedding = embedSize * floatBytes; + const sKvBufSize = sizeTBytes; + const sKvHead = uint32TBytes; + const sKvSize = uint32TBytes; + const sKvUsed = uint32TBytes; + // const sKv = this._estimateKvByteSize(contextSize); + const sKvCell = this._llama._consts.llamaPosSize + sizeTBytes + this._llama._consts.llamaSeqIdSize; + const kvSelfLength = this.ggufFileInfo.metadata.general.architecture === GgufArchitectureType.mamba + ? Math.max(1, sequences) + : contextSize; + const sKvCells = kvSelfLength * sKvCell; + + const overheadMemory = ( + sRngSize + + sRng + + sLogitsSize + + sLogits + + sEmbeddingSize + + sEmbedding + + sKvBufSize + + sKvHead + + sKvSize + + sKvUsed + + sKvCells + ); + + // Estimates the memory allocated by `ggml_backend_sched_reserve` in `llama_new_context_with_model` in `llama.cpp`. + // If you read this line and have better insights on how to estimate this memory, please open a PR to improve it :) + const estimateGraphOverheadMemory = () => { + const tensorInfo = this.ggufFileInfo.tensorInfo ?? []; + + const totalDimensions = tensorInfo.length === 0 + ? this.totalLayers * ( + ( + (this.ggufFileInfo.architectureMetadata.embedding_length ?? 0) + + (this.ggufFileInfo.architectureMetadata.feed_forward_length ?? 0) + ) / 2 + ) + : tensorInfo.reduce((res, tensor) => { + return res + tensor.dimensions.reduce((res: number, dim) => res + Number(dim), 0); + }, 0); + + // magic numbers for estimation. will be improved in the future + return totalDimensions * 77.655 * (contextSize / 4096); + }; + + const graphOverheadMemory = !includeGraphOverhead + ? 0 + : estimateGraphOverheadMemory(); + + const usingGpu = finalGpuLayers !== 0; + + const cpuRam = ( + !usingGpu + ? (overheadMemory + graphOverheadMemory) + : 0 + ) + + this._estimateKvMemorySizeInBytes(contextSize, finalCpuLayers); + const gpuVram = usingGpu + ? ( + overheadMemory + + graphOverheadMemory + + this._estimateKvMemorySizeInBytes( + contextSize, + finalGpuLayers < totalLayers + ? (finalGpuLayers + 1) + : finalGpuLayers + ) + ) + : 0; + + return { + cpuRam, + gpuVram + }; + } + + /** + * Get the split tensor resources for CPU and GPU based on the number of GPU layers + * @internal + */ + public _getTensorResourceSplit(gpuLayers: number): { cpu: GgufTensorInfo[], gpu: GgufTensorInfo[] } { @@ -80,6 +203,50 @@ export class GgufInsights { return layerNumbers.size; } + /** @internal */ + public _getFileLayers() { + return this.ggufFileInfo.architectureMetadata.block_count ?? this._determineNumberOfLayersFromTensorInfo(); + } + + /** @internal */ + public _estimateKvMemorySizeInBytes(contextSize: number, layers: number) { + // source: `llama_kv_cache_init` in `llama.cpp` + const nHead = this.ggufFileInfo.architectureMetadata.attention?.head_count ?? 0; + const nEmbd = this.ggufFileInfo.architectureMetadata.embedding_length ?? 0; + const nEmbdHeadK = this.ggufFileInfo.architectureMetadata.attention?.key_length ?? ((nHead == 0) ? 0 : (nEmbd / nHead)); + const nHeadKv = this.ggufFileInfo.architectureMetadata.attention?.head_count_kv ?? nHead; + const modelNEmbdKGqa = nEmbdHeadK * nHeadKv; + + const ssmDConv = this.ggufFileInfo.architectureMetadata.ssm?.conv_kernel ?? 0; + const ssmDInner = this.ggufFileInfo.architectureMetadata.ssm?.inner_size ?? 0; + const modelNEmbdKS = (ssmDConv > 0 ? (ssmDConv - 1) : 0) * ssmDInner; + + const nEmbdHeadV = this.ggufFileInfo.architectureMetadata.attention?.value_length ?? ((nHead == 0) ? 0 : nEmbd / nHead); + const modelNEmbdVGqa = nEmbdHeadV * nHeadKv; + + const ssmDState = this.ggufFileInfo.architectureMetadata.ssm?.state_size ?? 0; + const modelNEmbdVS = ssmDState * ssmDInner; + + const totalNEmbdKGqa = modelNEmbdKGqa + modelNEmbdKS; + const totalNEmbdVGqa = modelNEmbdVGqa + modelNEmbdVS; + + const keyTypeSize = this.ggufFileInfo.metadata.general.architecture === GgufArchitectureType.mamba + // if `type_k` of `llama_context_params` changes to be configurable in `LlamaContext`, + // this would have to depend on that value + ? this._llama._consts.ggmlTypeF32Size + : this._llama._consts.ggmlTypeF16Size; + const valueTypeSize = this.ggufFileInfo.metadata.general.architecture === GgufArchitectureType.mamba + // if `type_v` of `llama_context_params` changes to be configurable in `LlamaContext`, + // this would have to depend on that value + ? this._llama._consts.ggmlTypeF32Size + : this._llama._consts.ggmlTypeF16Size; + + const keyTensorsSize = layers * totalNEmbdKGqa * contextSize * keyTypeSize; + const valueTensorsSize = layers * totalNEmbdVGqa * contextSize * valueTypeSize; + + return keyTensorsSize + valueTensorsSize; + } + /** * @param ggufFileInfo * @param llama - If you already have a `Llama` instance, pass it to reuse it for the `GgufInsights` instance. diff --git a/src/gguf/parser/parseGguf.ts b/src/gguf/parser/parseGguf.ts index 473553f1..d1bd71cb 100644 --- a/src/gguf/parser/parseGguf.ts +++ b/src/gguf/parser/parseGguf.ts @@ -4,6 +4,7 @@ import {UnsupportedError} from "../../utils/UnsupportedError.js"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; import {GgufFileReader, valueTypeToBytesToRead} from "../fileReaders/GgufFileReader.js"; import {GgufFileInfo, GgufVersionParserOptions, GgufVersionParserResult} from "../types/GgufFileInfoTypes.js"; +import {getGgufMetadataArchitectureData} from "../utils/getGgufMetadataArchitectureData.js"; import {GgufV2Parser} from "./GgufV2Parser.js"; import {GgufV3Parser} from "./GgufV3Parser.js"; @@ -31,11 +32,13 @@ export async function parseGguf({ readOffset, logWarnings }); + const architectureMetadata = getGgufMetadataArchitectureData(gguifInfo.metadata); return { version: magicAndVersion.version, tensorCount: gguifInfo.tensorCount, metadata: gguifInfo.metadata, + architectureMetadata: architectureMetadata, tensorInfo: gguifInfo.tensorInfo, metadataSize: gguifInfo.metadataSize, tensorInfoSize: gguifInfo.tensorInfoSize diff --git a/src/gguf/types/GgufFileInfoTypes.ts b/src/gguf/types/GgufFileInfoTypes.ts index f1a8e789..d48efb6f 100644 --- a/src/gguf/types/GgufFileInfoTypes.ts +++ b/src/gguf/types/GgufFileInfoTypes.ts @@ -1,6 +1,7 @@ import {GgufReadOffset} from "../utils/GgufReadOffset.js"; import {GgufFileReader} from "../fileReaders/GgufFileReader.js"; -import {GgufMetadata} from "./GgufMetadataTypes.js"; +import {MergeOptionalUnionTypes} from "../../utils/mergeUnionTypes.js"; +import {GgufArchitectureType, GgufMetadata} from "./GgufMetadataTypes.js"; import {GgufTensorInfo} from "./GgufTensorInfoTypes.js"; export type MetadataValue = string | number | bigint | boolean | MetadataValue[]; @@ -15,6 +16,9 @@ export type GgufFileInfo = { metadata: GgufMetadata, metadataSize: number, + /** Same value as `metadata[metadata.general.architecture]`, but with merged types for convenience */ + architectureMetadata: MergeOptionalUnionTypes>, + /** can be null if `readTensorInfo` is set to `false` */ tensorInfo?: GgufTensorInfo[], diff --git a/src/gguf/utils/getGgufMetadataLlmData.ts b/src/gguf/utils/getGgufMetadataArchitectureData.ts similarity index 75% rename from src/gguf/utils/getGgufMetadataLlmData.ts rename to src/gguf/utils/getGgufMetadataArchitectureData.ts index c9ae20b2..475eb10a 100644 --- a/src/gguf/utils/getGgufMetadataLlmData.ts +++ b/src/gguf/utils/getGgufMetadataArchitectureData.ts @@ -1,7 +1,7 @@ import {GgufArchitectureType, GgufMetadata} from "../types/GgufMetadataTypes.js"; import {MergeOptionalUnionTypes} from "../../utils/mergeUnionTypes.js"; -export function getGgufMetadataLlmData(ggufMetadata: GgufMetadata): ( +export function getGgufMetadataArchitectureData(ggufMetadata: GgufMetadata): ( GgufArchitectureType extends T ? MergeOptionalUnionTypes> : GgufMetadata[T] diff --git a/test/modelDependent/functionary/gguf/ggufInsights.test.ts b/test/modelDependent/functionary/gguf/ggufInsights.test.ts index 17d85f50..bc3787b5 100644 --- a/test/modelDependent/functionary/gguf/ggufInsights.test.ts +++ b/test/modelDependent/functionary/gguf/ggufInsights.test.ts @@ -1,9 +1,9 @@ import {describe, expect, test} from "vitest"; +import bytes from "bytes"; import {getModelFile} from "../../../utils/modelFiles.js"; -import {GgufInsights} from "../../../../src/gguf/GgufInsights.js"; +import {GgufInsights, GgufInsightsResourceRequirements} from "../../../../src/gguf/GgufInsights.js"; import {getTestLlama} from "../../../utils/getTestLlama.js"; import {readGgufFileInfo} from "../../../../src/gguf/readGgufFileInfo.js"; -import {getGgufMetadataLlmData} from "../../../../src/gguf/utils/getGgufMetadataLlmData.js"; describe("gguf", async () => { describe("insights", async () => { @@ -13,9 +13,8 @@ describe("gguf", async () => { const llama = await getTestLlama(); const ggufMetadataParseResult = await readGgufFileInfo(modelPath); const insights = await GgufInsights.from(ggufMetadataParseResult, llama); - const llmData = getGgufMetadataLlmData(ggufMetadataParseResult.metadata); - expect(insights._determineNumberOfLayersFromTensorInfo()).to.be.eql(llmData.block_count); + expect(insights._determineNumberOfLayersFromTensorInfo()).to.be.eql(ggufMetadataParseResult.architectureMetadata.block_count); }); test("calculated model size stays the same", async () => { @@ -26,34 +25,104 @@ describe("gguf", async () => { expect(ggufInsights.modelSize).toMatchInlineSnapshot("4108204160"); }); - test("predicted VRAM usage should match actual VRAM usage", async (context) => { + test("estimated model memory footprint stays the same", async () => { + const llama = await getTestLlama(); + const ggufMetadataParseResult = await readGgufFileInfo(modelPath); + + const ggufInsights = await GgufInsights.from(ggufMetadataParseResult, llama); + expect(makeEstimationReadable(ggufInsights.estimateModelResourceRequirements({gpuLayers: 0}))).toMatchInlineSnapshot(` + { + "cpuRam": "3.83GB", + "gpuVram": "0B", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateModelResourceRequirements({gpuLayers: 1}))).toMatchInlineSnapshot(` + { + "cpuRam": "3.54GB", + "gpuVram": "289.92MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateModelResourceRequirements({gpuLayers: 8}))).toMatchInlineSnapshot(` + { + "cpuRam": "2.74GB", + "gpuVram": "1.08GB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateModelResourceRequirements({gpuLayers: 16}))).toMatchInlineSnapshot(` + { + "cpuRam": "1.83GB", + "gpuVram": "2GB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateModelResourceRequirements({gpuLayers: 24}))).toMatchInlineSnapshot(` + { + "cpuRam": "936.25MB", + "gpuVram": "2.91GB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateModelResourceRequirements({gpuLayers: 32}))).toMatchInlineSnapshot(` + { + "cpuRam": "0B", + "gpuVram": "3.83GB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateModelResourceRequirements({gpuLayers: 33}))).toMatchInlineSnapshot(` + { + "cpuRam": "0B", + "gpuVram": "3.83GB", + } + `); + }); + + test("predicted VRAM usage should match actual VRAM usage", async (testContext) => { const llama = await getTestLlama(); const ggufMetadataParseResult = await readGgufFileInfo(modelPath); if (llama.gpu === false) - return context.skip(); + return testContext.skip(); const ggufInsights = await GgufInsights.from(ggufMetadataParseResult, llama); - const initialVramUsage = llama.getVramState().used; + const initialModelVramUsage = llama.getVramState().used; const model = await llama.loadModel({ - modelPath: modelPath + modelPath: modelPath, + gpuLayers: ggufInsights.totalLayers }); - const currentVramUsage = llama.getVramState().used; + const currentModelVramUsage = llama.getVramState().used; - const vramUsageDiff = currentVramUsage - initialVramUsage; + const modelVramUsageDiff = currentModelVramUsage - initialModelVramUsage; const s100MB = 100 * Math.pow(1024, 2); const s5MB = 5 * Math.pow(1024, 2); - expect(ggufInsights.modelSize).toMatchInlineSnapshot("4108204160"); - expect(Math.abs(vramUsageDiff - ggufInsights.modelSize)).to.be.lte(s100MB); + const estimatedModelVramUsage = ggufInsights.estimateModelResourceRequirements({gpuLayers: ggufInsights.totalLayers}).gpuVram; + expect(bytes(estimatedModelVramUsage)).toMatchInlineSnapshot('"3.83GB"'); + expect(Math.abs(modelVramUsageDiff - estimatedModelVramUsage)).to.be.lte(s100MB); + + const modelEstimationDiffWithActual = estimatedModelVramUsage - model.size; + expect(Math.abs(modelEstimationDiffWithActual)).to.be.lte(s5MB); // tolerate such a small difference - const calculationDiffWithActual = ggufInsights.modelSize - model.size; - expect(Math.abs(calculationDiffWithActual)).to.be.lte(s5MB); // tolerate such a small difference + if (modelEstimationDiffWithActual !== 0) + console.warn("Model size estimation is off by", modelEstimationDiffWithActual, "bytes"); - if (calculationDiffWithActual !== 0) - console.warn("Model size calculation is off by", calculationDiffWithActual, "bytes"); + const initialContextVramUsage = llama.getVramState().used; + const context = await model.createContext({ + contextSize: 4096, + batchSize: 512, + sequences: 1 + }); + const currentContextVramUsage = llama.getVramState().used; + + const contextVramUsageDiff = currentContextVramUsage - initialContextVramUsage; + + const estimatedContextVramUsage = ggufInsights.estimateContextResourceRequirements({ + contextSize: context.contextSize, + batchSize: context.batchSize, + sequences: context.totalSequences, + modelGpuLayers: ggufInsights.totalLayers + }).gpuVram; + expect(bytes(estimatedContextVramUsage)).toMatchInlineSnapshot('"809.83MB"'); + expect(Math.abs(contextVramUsageDiff - estimatedContextVramUsage)).to.be.lte(s100MB); await model.dispose(); }); @@ -77,11 +146,249 @@ describe("gguf", async () => { const vramUsageDiff = currentVramUsage - initialVramUsage; const s100MB = 100 * Math.pow(1024, 2); - const calculatedVramUsage = ggufInsights.calculateModelResourceRequirements(16).gpuVram; + const calculatedVramUsage = ggufInsights.estimateModelResourceRequirements({gpuLayers: 16}).gpuVram; expect(Math.abs(vramUsageDiff - calculatedVramUsage)).to.be.lte(s100MB); await model.dispose(); }); + + test("estimated context memory footprint stays the same", async () => { + const llama = await getTestLlama(); + const ggufMetadataParseResult = await readGgufFileInfo(modelPath); + + const ggufInsights = await GgufInsights.from(ggufMetadataParseResult, llama); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 8192, + modelGpuLayers: 0, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "1.52GB", + "gpuVram": "0B", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 4096, + modelGpuLayers: 0, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "809.83MB", + "gpuVram": "0B", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 2048, + modelGpuLayers: 0, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "436.2MB", + "gpuVram": "0B", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 1024, + modelGpuLayers: 0, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "249.39MB", + "gpuVram": "0B", + } + `); + + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 8192, + modelGpuLayers: 1, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "1GB", + "gpuVram": "565.1MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 4096, + modelGpuLayers: 1, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "512MB", + "gpuVram": "313.83MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 2048, + modelGpuLayers: 1, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "256MB", + "gpuVram": "188.2MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 1024, + modelGpuLayers: 1, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "128MB", + "gpuVram": "125.39MB", + } + `); + + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 8192, + modelGpuLayers: 16, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "544MB", + "gpuVram": "1.02GB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 4096, + modelGpuLayers: 16, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "272MB", + "gpuVram": "553.83MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 2048, + modelGpuLayers: 16, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "136MB", + "gpuVram": "308.2MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 1024, + modelGpuLayers: 16, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "68MB", + "gpuVram": "185.39MB", + } + `); + + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 8192, + modelGpuLayers: 32, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "32MB", + "gpuVram": "1.52GB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 4096, + modelGpuLayers: 32, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "16MB", + "gpuVram": "809.83MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 2048, + modelGpuLayers: 32, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "8MB", + "gpuVram": "436.2MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 1024, + modelGpuLayers: 32, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "4MB", + "gpuVram": "249.39MB", + } + `); + + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 8192, + modelGpuLayers: 33, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "0B", + "gpuVram": "1.52GB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 4096, + modelGpuLayers: 33, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "0B", + "gpuVram": "809.83MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 2048, + modelGpuLayers: 33, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "0B", + "gpuVram": "436.2MB", + } + `); + expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ + contextSize: 1024, + modelGpuLayers: 33, + sequences: 1, + batchSize: 512 + }))).toMatchInlineSnapshot(` + { + "cpuRam": "0B", + "gpuVram": "249.39MB", + } + `); + }); }); }); + +function makeEstimationReadable(resourceRequirements: GgufInsightsResourceRequirements) { + return { + cpuRam: bytes(resourceRequirements.cpuRam), + gpuVram: bytes(resourceRequirements.gpuVram) + }; +} From d30f06b2a67234725df2ee13eb4c9feb9c2c8617 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 29 Mar 2024 01:29:01 +0200 Subject: [PATCH 23/52] feat: flexible default `gpuLayers` and `contextSize` options that depend on the currently available VRAM, token meter to track token usage --- .vitepress/config.ts | 2 +- src/bindings/Llama.ts | 28 +- src/bindings/getLlama.ts | 44 +- src/bindings/types.ts | 18 +- src/bindings/utils/MemoryOrchestrator.ts | 63 ++ src/bindings/utils/getLlamaWithoutBackend.ts | 3 +- ...esolveChatWrapperBasedOnWrapperTypeName.ts | 11 +- .../resolveChatWrapperBasedOnModel.ts | 13 +- .../inspect/commands/InspectGpuCommand.ts | 3 +- src/config.ts | 1 + src/evaluator/LlamaContext/LlamaContext.ts | 201 +++++- src/evaluator/LlamaContext/types.ts | 67 +- .../firstInFirstOutStrategy.ts | 0 .../maximumParallelismStrategy.ts | 0 ...esolveBatchItemsPrioritizationStrategy.ts} | 6 +- src/evaluator/LlamaEmbeddingContext.ts | 43 +- src/evaluator/LlamaModel.ts | 346 +++++++++- src/evaluator/TokenMeter.ts | 108 +++ src/gguf/GgufInsights.ts | 20 +- src/gguf/consts.ts | 2 +- src/gguf/fileReaders/GgufFsFileReader.ts | 23 +- .../fileReaders/GgufNetworkFetchFileReader.ts | 31 +- src/gguf/readGgufFileInfo.ts | 41 +- src/gguf/types/GgufFileInfoTypes.ts | 24 +- src/gguf/types/GgufMetadataTypes.ts | 364 +++++----- src/gguf/types/GgufTensorInfoTypes.ts | 8 +- src/index.ts | 15 +- src/utils/InsufficientMemoryError.ts | 5 + src/utils/findBestOption.ts | 21 + src/utils/parseModelTypeDescription.ts | 11 - src/utils/resolveChatWrapper.ts | 2 +- .../functionary/chatSession.test.ts | 97 +++ .../functionaryModelGpuLayersOptions.test.ts | 627 ++++++++++++++++++ .../stableCodeModelGpuLayersOptions.test.ts | 627 ++++++++++++++++++ 34 files changed, 2550 insertions(+), 325 deletions(-) create mode 100644 src/bindings/utils/MemoryOrchestrator.ts rename src/evaluator/LlamaContext/utils/{batchItemsPrioritizingStrategies => batchItemsPrioritizationStrategies}/firstInFirstOutStrategy.ts (100%) rename src/evaluator/LlamaContext/utils/{batchItemsPrioritizingStrategies => batchItemsPrioritizationStrategies}/maximumParallelismStrategy.ts (100%) rename src/evaluator/LlamaContext/utils/{resolveBatchItemsPrioritizingStrategy.ts => resolveBatchItemsPrioritizationStrategy.ts} (54%) create mode 100644 src/evaluator/TokenMeter.ts create mode 100644 src/utils/InsufficientMemoryError.ts create mode 100644 src/utils/findBestOption.ts delete mode 100644 src/utils/parseModelTypeDescription.ts create mode 100644 test/modelDependent/functionary/functionaryModelGpuLayersOptions.test.ts create mode 100644 test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts diff --git a/.vitepress/config.ts b/.vitepress/config.ts index dd48d2dd..b6cff6ac 100644 --- a/.vitepress/config.ts +++ b/.vitepress/config.ts @@ -351,7 +351,7 @@ function orderTypes(sidebar: typeof typedocSidebar) { (item) => ( item.text === "BatchItem" || item.text === "CustomBatchingDispatchSchedule" || - item.text === "CustomBatchingPrioritizeStrategy" || + item.text === "CustomBatchingPrioritizationStrategy" || item.text === "PrioritizedBatchItem" ), {collapsed: false} diff --git a/src/bindings/Llama.ts b/src/bindings/Llama.ts index 42063801..20cd5ec2 100644 --- a/src/bindings/Llama.ts +++ b/src/bindings/Llama.ts @@ -5,6 +5,7 @@ import {LlamaModel, LlamaModelOptions} from "../evaluator/LlamaModel.js"; import {DisposeGuard} from "../utils/DisposeGuard.js"; import {BindingModule} from "./AddonTypes.js"; import {BuildGpu, BuildMetadataFile, LlamaLocks, LlamaLogLevel} from "./types.js"; +import {MemoryOrchestrator, MemoryReservation} from "./utils/MemoryOrchestrator.js"; const LlamaLogLevelToAddonLogLevel: ReadonlyMap = new Map([ [LlamaLogLevel.disabled, 0], @@ -24,6 +25,8 @@ export class Llama { /** @internal */ public readonly _backendDisposeGuard = new DisposeGuard(); /** @internal */ public readonly _memoryLock = {}; /** @internal */ public readonly _consts: ReturnType; + /** @internal */ public readonly _vramOrchestrator: MemoryOrchestrator; + /** @internal */ public readonly _vramPadding: MemoryReservation; /** @internal */ private readonly _gpu: BuildGpu; /** @internal */ private readonly _buildType: "localBuild" | "prebuilt"; /** @internal */ private readonly _cmakeOptions: Readonly>; @@ -47,7 +50,7 @@ export class Llama { public readonly onDispose = new EventRelay(); private constructor({ - bindings, logLevel, logger, buildType, cmakeOptions, llamaCppRelease + bindings, logLevel, logger, buildType, cmakeOptions, llamaCppRelease, vramPadding }: { bindings: BindingModule, logLevel: LlamaLogLevel, @@ -57,7 +60,8 @@ export class Llama { llamaCppRelease: { repo: string, release: string - } + }, + vramPadding: number | ((totalVram: number) => number) }) { this._bindings = bindings; this._gpu = bindings.getGpuType() ?? false; @@ -65,6 +69,20 @@ export class Llama { this._supportsMmap = bindings.getSupportsMmap(); this._supportsMlock = bindings.getSupportsMlock(); this._consts = bindings.getConsts(); + + this._vramOrchestrator = new MemoryOrchestrator(() => { + const {total, used} = bindings.getGpuVramInfo(); + + return { + total, + free: Math.max(0, total - used) + }; + }); + if (vramPadding instanceof Function) + this._vramPadding = this._vramOrchestrator.reserveMemory(vramPadding(this._vramOrchestrator.getMemoryState().total)); + else + this._vramPadding = this._vramOrchestrator.reserveMemory(vramPadding); + this._logLevel = logLevel ?? LlamaLogLevel.debug; this._logger = logger; this._buildType = buildType; @@ -283,13 +301,14 @@ export class Llama { /** @internal */ public static async _create({ - bindings, buildType, buildMetadata, logLevel, logger, skipLlamaInit = false + bindings, buildType, buildMetadata, logLevel, logger, vramPadding, skipLlamaInit = false }: { bindings: BindingModule, buildType: "localBuild" | "prebuilt", buildMetadata: BuildMetadataFile, logLevel: LlamaLogLevel, logger: (level: LlamaLogLevel, message: string) => void, + vramPadding: number | ((totalVram: number) => number), skipLlamaInit?: boolean }) { const llama = new Llama({ @@ -301,7 +320,8 @@ export class Llama { release: buildMetadata.buildOptions.llamaCpp.release }, logLevel, - logger + logger, + vramPadding }); if (!skipLlamaInit) diff --git a/src/bindings/getLlama.ts b/src/bindings/getLlama.ts index b577f986..37702bff 100644 --- a/src/bindings/getLlama.ts +++ b/src/bindings/getLlama.ts @@ -100,7 +100,16 @@ export type LlamaOptions = { * When set to `true`, and llama.cpp source is not found, a `NoBinaryFoundError` error will be thrown. * Disabled by default. */ - skipDownload?: boolean + skipDownload?: boolean, + + /** + * Pad the available VRAM for the memory size calculations, as these calculations are not always accurate. + * Recommended to ensure stability. + * + * Defaults to `1.5%` of the total VRAM or 300MB, whichever is lower. + * Set to `0` to disable. + */ + vramPadding?: number | ((totalVram: number) => number) }; export type LastBuildOptions = { @@ -133,11 +142,23 @@ export type LastBuildOptions = { * When set to `true`, and llama.cpp source is needed but is not found, a `NoBinaryFoundError` error will be thrown. * Disabled by default. */ - skipDownload?: boolean + skipDownload?: boolean, + + /** + * Pad the available VRAM for the memory size calculations, as these calculations are not always accurate. + * Recommended to ensure stability. + * This only affects the calculations of `"auto"` in function options and is not reflected in the `getVramState` function. + * + * Defaults to `1.5%` of the total VRAM or 300MB, whichever is lower. + * Set to `0` to disable. + */ + vramPadding?: number | ((totalVram: number) => number) }; export const getLlamaFunctionName = "getLlama"; +export const defaultLlamaVramPadding = (totalVram: number) => Math.floor(Math.min(totalVram * 0.015, 300 * 1024 * 1024)); + /** * Get a llama.cpp binding. * Defaults to prefer a prebuilt binary, and fallbacks to building from source if a prebuilt binary is not found. @@ -153,7 +174,8 @@ export async function getLlama(options?: LlamaOptions | "lastBuild", lastBuildOp logger: lastBuildOptions?.logger ?? Llama.defaultConsoleLogger, usePrebuiltBinaries: lastBuildOptions?.usePrebuiltBinaries ?? true, progressLogs: lastBuildOptions?.progressLogs ?? true, - skipDownload: lastBuildOptions?.skipDownload ?? defaultSkipDownload + skipDownload: lastBuildOptions?.skipDownload ?? defaultSkipDownload, + vramPadding: lastBuildOptions?.vramPadding ?? defaultLlamaVramPadding }; if (lastBuildInfo == null) @@ -173,7 +195,8 @@ export async function getLlama(options?: LlamaOptions | "lastBuild", lastBuildOp buildType: "localBuild", buildMetadata, logger: lastBuildOptions?.logger ?? Llama.defaultConsoleLogger, - logLevel: lastBuildOptions?.logLevel ?? defaultLlamaCppDebugLogs + logLevel: lastBuildOptions?.logLevel ?? defaultLlamaCppDebugLogs, + vramPadding: lastBuildOptions?.vramPadding ?? defaultLlamaVramPadding }); } catch (err) { console.error(getConsoleLogPrefix() + "Failed to load last build. Error:", err); @@ -196,7 +219,8 @@ export async function getLlamaForOptions({ existingPrebuiltBinaryMustMatchBuildOptions = false, usePrebuiltBinaries = true, progressLogs = true, - skipDownload = defaultSkipDownload + skipDownload = defaultSkipDownload, + vramPadding = defaultLlamaVramPadding }: LlamaOptions, { updateLastBuildInfoOnCompile = false, skipLlamaInit = false @@ -215,6 +239,7 @@ export async function getLlamaForOptions({ if (usePrebuiltBinaries == null) usePrebuiltBinaries = true; if (progressLogs == null) progressLogs = true; if (skipDownload == null) skipDownload = defaultSkipDownload; + if (vramPadding == null) vramPadding = defaultLlamaVramPadding; const clonedLlamaCppRepoReleaseInfo = await getClonedLlamaCppRepoReleaseInfo(); let canUsePrebuiltBinaries = (build === "forceRebuild" || !usePrebuiltBinaries) @@ -260,6 +285,7 @@ export async function getLlamaForOptions({ platform, platformInfo, skipLlamaInit, + vramPadding, fallbackMessage: !isLastItem ? `falling back to using ${getPrettyBuildGpuName(buildGpusToTry[i + 1])}` : ( @@ -315,6 +341,7 @@ export async function getLlamaForOptions({ logLevel, logger, updateLastBuildInfoOnCompile, + vramPadding, skipLlamaInit }); } catch (err) { @@ -348,6 +375,7 @@ async function loadExistingLlamaBinary({ platform, platformInfo, skipLlamaInit, + vramPadding, fallbackMessage }: { buildOptions: BuildOptions, @@ -359,6 +387,7 @@ async function loadExistingLlamaBinary({ platform: BinaryPlatform, platformInfo: BinaryPlatformInfo, skipLlamaInit: boolean, + vramPadding: Required["vramPadding"], fallbackMessage: string | null }) { const buildFolderName = await getBuildFolderNameForBuildOptions(buildOptions); @@ -389,6 +418,7 @@ async function loadExistingLlamaBinary({ buildMetadata, logLevel, logger, + vramPadding, skipLlamaInit }); } else if (progressLogs) { @@ -441,6 +471,7 @@ async function loadExistingLlamaBinary({ buildMetadata, logLevel, logger, + vramPadding, skipLlamaInit }); } else if (progressLogs) @@ -484,6 +515,7 @@ async function buildAndLoadLlamaBinary({ logLevel, logger, updateLastBuildInfoOnCompile, + vramPadding, skipLlamaInit }: { buildOptions: BuildOptions, @@ -491,6 +523,7 @@ async function buildAndLoadLlamaBinary({ logLevel: Required["logLevel"], logger: Required["logger"], updateLastBuildInfoOnCompile: boolean, + vramPadding: Required["vramPadding"], skipLlamaInit: boolean }) { const buildFolderName = await getBuildFolderNameForBuildOptions(buildOptions); @@ -519,6 +552,7 @@ async function buildAndLoadLlamaBinary({ buildMetadata, logLevel, logger, + vramPadding, skipLlamaInit }); } diff --git a/src/bindings/types.ts b/src/bindings/types.ts index e39b8a40..528806a0 100644 --- a/src/bindings/types.ts +++ b/src/bindings/types.ts @@ -69,14 +69,28 @@ export enum LlamaLogLevel { info = "info", debug = "debug" } -export const LlamaLogLevelValues = [ +export const LlamaLogLevelValues = Object.freeze([ LlamaLogLevel.disabled, LlamaLogLevel.fatal, LlamaLogLevel.error, LlamaLogLevel.warn, LlamaLogLevel.info, LlamaLogLevel.debug -] as const; +] as const); + +/** + *Check if a log level is higher than another log level + */ +export function LlamaLogLevelGreaterThan(a: LlamaLogLevel, b: LlamaLogLevel): boolean { + return LlamaLogLevelValues.indexOf(a) < LlamaLogLevelValues.indexOf(b); +} + +/** + *Check if a log level is higher than or equal to another log level + */ +export function LlamaLogLevelGreaterThanOrEqual(a: LlamaLogLevel, b: LlamaLogLevel): boolean { + return LlamaLogLevelValues.indexOf(a) <= LlamaLogLevelValues.indexOf(b); +} export const enum LlamaLocks { loadToMemory = "loadToMemory" diff --git a/src/bindings/utils/MemoryOrchestrator.ts b/src/bindings/utils/MemoryOrchestrator.ts new file mode 100644 index 00000000..e4dbc99d --- /dev/null +++ b/src/bindings/utils/MemoryOrchestrator.ts @@ -0,0 +1,63 @@ +import {EventRelay} from "lifecycle-utils"; + +export class MemoryOrchestrator { + /** @internal */ private readonly _getMemoryState: () => {free: number, total: number}; + /** @internal */ private _reservedMemory: number = 0; + + public readonly onMemoryReservationRelease = new EventRelay(); + + public constructor(getMemoryState: () => {free: number, total: number}) { + this._getMemoryState = getMemoryState; + } + + public reserveMemory(bytes: number) { + this._reservedMemory += bytes; + + return MemoryReservation._create(bytes, () => { + this._reservedMemory -= bytes; + this.onMemoryReservationRelease.dispatchEvent(); + }); + } + + public getMemoryState() { + const {free, total} = this._getMemoryState(); + + return { + free: Math.max(0, free - this._reservedMemory), + total + }; + } +} + +export class MemoryReservation { + /** @internal */ private readonly _size: number; + /** @internal */ private _dispose: (() => void) | null; + + private constructor(size: number, dispose: () => void) { + this._size = size; + this._dispose = dispose; + } + + public get size(): number { + return this._size; + } + + public get disposed(): boolean { + return this._dispose == null; + } + + public [Symbol.dispose](): void { + this.dispose(); + } + + public dispose(): void { + if (this._dispose != null) + this._dispose(); + + this._dispose = null; + } + + public static _create(bytes: number, dispose: () => void): MemoryReservation { + return new MemoryReservation(bytes, dispose); + } +} diff --git a/src/bindings/utils/getLlamaWithoutBackend.ts b/src/bindings/utils/getLlamaWithoutBackend.ts index f52c9589..6a64d59f 100644 --- a/src/bindings/utils/getLlamaWithoutBackend.ts +++ b/src/bindings/utils/getLlamaWithoutBackend.ts @@ -21,7 +21,8 @@ export async function getLlamaWithoutBackend() { progressLogs: false, logLevel: LlamaLogLevel.error, build: "never", - usePrebuiltBinaries: true + usePrebuiltBinaries: true, + vramPadding: 0 }, { skipLlamaInit: true }); diff --git a/src/bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.ts b/src/bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.ts index e0afd225..e24e7fd5 100644 --- a/src/bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.ts +++ b/src/bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.ts @@ -1,4 +1,3 @@ -import {ModelTypeDescription} from "../AddonTypes.js"; import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js"; import {LlamaChatWrapper} from "../../chatWrappers/LlamaChatWrapper.js"; import {AlpacaChatWrapper} from "../../chatWrappers/AlpacaChatWrapper.js"; @@ -7,6 +6,7 @@ import {ChatMLChatWrapper} from "../../chatWrappers/ChatMLChatWrapper.js"; import {FalconChatWrapper} from "../../chatWrappers/FalconChatWrapper.js"; import {resolveChatWrapperBasedOnModel} from "../../chatWrappers/resolveChatWrapperBasedOnModel.js"; import {GemmaChatWrapper} from "../../chatWrappers/GemmaChatWrapper.js"; +import {GgufFileInfo} from "../../gguf/types/GgufFileInfoTypes.js"; export const chatWrapperTypeNames = Object.freeze([ "auto", "general", "llamaChat", "alpacaChat", "functionary", "chatML", "falconChat", "gemma" @@ -33,15 +33,12 @@ const chatWrapperToConfigType = new Map( export function resolveChatWrapperBasedOnWrapperTypeName(configType: ChatWrapperTypeName, { bosString, filename, - typeDescription, + fileInfo, customWrapperSettings }: { bosString?: string | null, filename?: string, - - /** @hidden this type alias is too long in the documentation */ - typeDescription?: ModelTypeDescription, - + fileInfo?: GgufFileInfo, customWrapperSettings?: { [wrapper in keyof typeof chatWrappers]?: ConstructorParameters<(typeof chatWrappers)[wrapper]>[0] } @@ -58,7 +55,7 @@ export function resolveChatWrapperBasedOnWrapperTypeName(configType: ChatWrapper const chatWrapper = resolveChatWrapperBasedOnModel({ bosString, filename, - typeDescription + fileInfo }); if (chatWrapper != null) { diff --git a/src/chatWrappers/resolveChatWrapperBasedOnModel.ts b/src/chatWrappers/resolveChatWrapperBasedOnModel.ts index 2607c9b0..7aeb9773 100644 --- a/src/chatWrappers/resolveChatWrapperBasedOnModel.ts +++ b/src/chatWrappers/resolveChatWrapperBasedOnModel.ts @@ -1,6 +1,4 @@ import {parseModelFileName} from "../utils/parseModelFileName.js"; -import {parseModelTypeDescription} from "../utils/parseModelTypeDescription.js"; -import {ModelTypeDescription} from "../bindings/AddonTypes.js"; import {LlamaChatWrapper} from "./LlamaChatWrapper.js"; import {ChatMLChatWrapper} from "./ChatMLChatWrapper.js"; import {GeneralChatWrapper} from "./GeneralChatWrapper.js"; @@ -8,6 +6,7 @@ import {FalconChatWrapper} from "./FalconChatWrapper.js"; import {FunctionaryChatWrapper} from "./FunctionaryChatWrapper.js"; import {AlpacaChatWrapper} from "./AlpacaChatWrapper.js"; import {GemmaChatWrapper} from "./GemmaChatWrapper.js"; +import type {GgufFileInfo} from "../gguf/types/GgufFileInfoTypes.js"; /** @@ -16,13 +15,11 @@ import {GemmaChatWrapper} from "./GemmaChatWrapper.js"; export function resolveChatWrapperBasedOnModel({ bosString, filename, - typeDescription + fileInfo }: { bosString?: string | null, filename?: string, - - /** @hidden this type alias is too long in the documentation */ - typeDescription?: ModelTypeDescription + fileInfo?: GgufFileInfo }) { if (filename != null) { const {name, subType, fileType} = parseModelFileName(filename); @@ -59,8 +56,8 @@ export function resolveChatWrapperBasedOnModel({ } } - if (typeDescription != null) { - const {arch} = parseModelTypeDescription(typeDescription); + if (fileInfo != null) { + const arch = fileInfo.metadata.general?.architecture; if (arch === "llama") return LlamaChatWrapper; diff --git a/src/cli/commands/inspect/commands/InspectGpuCommand.ts b/src/cli/commands/inspect/commands/InspectGpuCommand.ts index 7c918fd8..2ce3ab63 100644 --- a/src/cli/commands/inspect/commands/InspectGpuCommand.ts +++ b/src/cli/commands/inspect/commands/InspectGpuCommand.ts @@ -58,7 +58,8 @@ async function logGpuVramUsage(gpu: BuildGpu) { gpu: gpu, build: "never", progressLogs: false, - logLevel: LlamaLogLevel.warn + logLevel: LlamaLogLevel.warn, + vramPadding: 0 }, { skipLlamaInit: true }); diff --git a/src/config.ts b/src/config.ts index 56508e98..d2b8a0ad 100644 --- a/src/config.ts +++ b/src/config.ts @@ -90,3 +90,4 @@ export const documentationPageUrls = { Vulkan: documentationUrl + "/guide/vulkan" } as const; export const recommendedBaseDockerImage = "node:20"; +export const minAllowedContextSizeInCalculations = 24; diff --git a/src/evaluator/LlamaContext/LlamaContext.ts b/src/evaluator/LlamaContext/LlamaContext.ts index 9a05a32d..440e1c9c 100644 --- a/src/evaluator/LlamaContext/LlamaContext.ts +++ b/src/evaluator/LlamaContext/LlamaContext.ts @@ -5,11 +5,15 @@ import {BatchLogitIndex, AddonContext} from "../../bindings/AddonTypes.js"; import {LlamaGrammarEvaluationState} from "../LlamaGrammarEvaluationState.js"; import {compareTokens} from "../../utils/compareTokens.js"; import {DisposalPreventionHandle, DisposeGuard} from "../../utils/DisposeGuard.js"; +import {GgufInsights} from "../../gguf/GgufInsights.js"; +import {minAllowedContextSizeInCalculations} from "../../config.js"; +import {TokenMeter} from "../TokenMeter.js"; +import {BuildGpu} from "../../bindings/types.js"; import { BatchingOptions, BatchItem, ContextShiftOptions, ContextTokensDeleteRange, EvaluationPriority, LlamaContextOptions, LlamaContextSequenceRepeatPenalty, PrioritizedBatchItem } from "./types.js"; -import {resolveBatchItemsPrioritizingStrategy} from "./utils/resolveBatchItemsPrioritizingStrategy.js"; +import {resolveBatchItemsPrioritizationStrategy} from "./utils/resolveBatchItemsPrioritizationStrategy.js"; import type {Llama} from "../../bindings/Llama.js"; import type {LlamaModel} from "../LlamaModel.js"; @@ -44,18 +48,22 @@ export class LlamaContext { }: { _model: LlamaModel }, { - sequences = 1, + sequences, seed = null, - contextSize = _model.trainContextSize, - batchSize = Math.min(contextSize * sequences, 512), + contextSize, + batchSize, threads = 6, batching: { dispatchSchedule: batchingDispatchSchedule = "nextTick", - itemsPrioritizingStrategy: batchingItemsPrioritizingStrategy = "maximumParallelism" + itemPrioritizationStrategy: batchingItemsPrioritizationStrategy = "maximumParallelism" } = {}, _embeddings, _noSeed - }: LlamaContextOptions) { + }: LlamaContextOptions & { + sequences: number, + contextSize: number, + batchSize: number + }) { if (_model.disposed) throw new DisposedError(); @@ -77,7 +85,7 @@ export class LlamaContext { })); this._batchingOptions = { dispatchSchedule: batchingDispatchSchedule, - itemsPrioritizingStrategy: batchingItemsPrioritizingStrategy + itemPrioritizationStrategy: batchingItemsPrioritizationStrategy }; this._reclaimUnusedSequenceId = this._reclaimUnusedSequenceId.bind(this); @@ -166,9 +174,14 @@ export class LlamaContext { contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(this.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" - } = {} + } = {}, + + _tokenMeter }: { - contextShift?: ContextShiftOptions + contextShift?: ContextShiftOptions, + + /** @internal */ + _tokenMeter?: TokenMeter } = {}): LlamaContextSequence { this._ensureNotDisposed(); @@ -180,6 +193,7 @@ export class LlamaContext { return LlamaContextSequence._create({ sequenceId: nextSequenceId, context: this, + tokenMeter: _tokenMeter, contextShift: { size: contextShiftSize, strategy: contextShiftStrategy @@ -203,10 +217,10 @@ export class LlamaContext { let shouldHaveAnotherLoop = this._queuedDecodes.length > 0; - const resolvePrioritizingStrategy = () => { + const resolvePrioritizationStrategy = () => { try { this._ensureNotDisposed(); - return resolveBatchItemsPrioritizingStrategy(this._batchingOptions.itemsPrioritizingStrategy); + return resolveBatchItemsPrioritizationStrategy(this._batchingOptions.itemPrioritizationStrategy); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set(this._queuedDecodes), err); } @@ -215,7 +229,7 @@ export class LlamaContext { }; const getOrderedQueuedDecodes = ( - prioritizeStrategy: ReturnType + prioritizationStrategy: ReturnType ): null | CurrentBatchItem[] => { const batchItemToQueuedDecodeMap = new Map(); const batchItemsList: BatchItem[] = []; @@ -231,7 +245,7 @@ export class LlamaContext { let prioritizedItems: PrioritizedBatchItem[]; try { - prioritizedItems = prioritizeStrategy({ + prioritizedItems = prioritizationStrategy({ items: batchItemsList, size: this._batchSize }); @@ -303,11 +317,20 @@ export class LlamaContext { for (const {queuedDecode, processAmount} of batchItems) { let batchLogitIndex: ReturnType; try { + const shouldGenerateLogitAtTheEnd = queuedDecode.generateLogitAtTheEnd && + processAmount === queuedDecode.tokens.length; + + const tokensToProcess = queuedDecode.tokens.slice(0, processAmount); + + const numberOfOutputTokens = shouldGenerateLogitAtTheEnd ? 1 : 0; + TokenMeter.useTokens(queuedDecode.tokenMeter, Math.max(0, tokensToProcess.length - numberOfOutputTokens), "input"); + TokenMeter.useTokens(queuedDecode.tokenMeter, numberOfOutputTokens, "output"); + batchLogitIndex = this._ctx.addToBatch( queuedDecode.sequenceId, queuedDecode.firstTokenSequenceIndex, - Uint32Array.from(queuedDecode.tokens.slice(0, processAmount)), - queuedDecode.generateLogitAtTheEnd && processAmount === queuedDecode.tokens.length + Uint32Array.from(tokensToProcess), + shouldGenerateLogitAtTheEnd ); } catch (err) { this._dispatchErrorForQueuedDecodesAndDequeue(new Set([queuedDecode]), err); @@ -359,11 +382,11 @@ export class LlamaContext { } }; - const prioritizeStrategy = resolvePrioritizingStrategy(); - if (prioritizeStrategy == null) return; // all queued items are rejected and dequeued when we get here + const prioritizationStrategy = resolvePrioritizationStrategy(); + if (prioritizationStrategy == null) return; // all queued items are rejected and dequeued when we get here while (shouldHaveAnotherLoop) { - const orderedQueuedDecodes = getOrderedQueuedDecodes(prioritizeStrategy); + const orderedQueuedDecodes = getOrderedQueuedDecodes(prioritizationStrategy); if (orderedQueuedDecodes == null) return; // all queued items are rejected and dequeued when we get here const { @@ -390,6 +413,11 @@ export class LlamaContext { }); } + /** + * Print the timings of token evaluation since that last print for this context. + * > **Note:** it prints on the `LlamaLogLevel.info` level, so if you set the level of your `Llama` instance higher than that, + * it won't print anything. + */ public async printTimings() { this._ensureNotDisposed(); @@ -399,10 +427,10 @@ export class LlamaContext { /** @internal */ public async _decodeTokens({ - sequenceId, firstTokenSequenceIndex, tokens, generateLogitAtTheEnd = false, evaluationPriority = 5 + sequenceId, firstTokenSequenceIndex, tokens, generateLogitAtTheEnd = false, evaluationPriority = 5, tokenMeter }: { sequenceId: number, firstTokenSequenceIndex: number, tokens: Token[], generateLogitAtTheEnd?: boolean, - evaluationPriority?: EvaluationPriority + evaluationPriority?: EvaluationPriority, tokenMeter: TokenMeter }, onDone?: ((batchLogitIndex: BatchLogitIndex) => (T | Promise))): Promise { return await new Promise((accept, reject) => { this._queuedDecodes.push({ @@ -411,6 +439,7 @@ export class LlamaContext { firstTokenSequenceIndex, generateLogitAtTheEnd, evaluationPriority, + tokenMeter, response: [accept, reject], onDone }); @@ -508,7 +537,22 @@ export class LlamaContext { public static async _create(options: LlamaContextOptions, {_model}: { _model: LlamaModel }): Promise { - const context = new LlamaContext({_model}, options); + const sequences = options.sequences ?? getDefaultContextSequences(); + const contextSize = resolveContextContextSizeOption({ + contextSize: options.contextSize, + batchSize: options.batchSize, + sequences: sequences, + modelFileInsights: _model.fileInsights, + modelGpuLayers: _model.gpuLayers, + modelTrainContextSize: _model.trainContextSize, + getVramState: () => _model._llama._vramOrchestrator.getMemoryState(), + llamaGpu: _model._llama.gpu, + ignoreMemorySafetyChecks: options.ignoreMemorySafetyChecks, + isEmbeddingContext: options._embeddings + }); + const batchSize = options.batchSize ?? getDefaultContextBatchSize({contextSize, sequences}); + + const context = new LlamaContext({_model}, {...options, contextSize, batchSize, sequences}); const {createSignal} = options; const contextLoaded = await context._ctx.init(); @@ -530,6 +574,7 @@ export class LlamaContextSequence { /** @internal */ private readonly _gcRegistry: FinalizationRegistry; /** @internal */ private readonly _context: LlamaContext; /** @internal */ private readonly _contextShift: Required; + /** @internal */ private readonly _tokenMeter: TokenMeter; /** @internal */ private readonly _disposeAggregator = new DisposeAggregator(); /** @internal */ private _contextTokens: Token[] = []; /** @internal */ private _nextTokenIndex: number = 0; @@ -538,14 +583,16 @@ export class LlamaContextSequence { public readonly onDispose = new EventRelay(); private constructor({ - sequenceId, context, contextShift + sequenceId, context, tokenMeter, contextShift }: { sequenceId: number, context: LlamaContext, + tokenMeter?: TokenMeter, contextShift: Required }) { this._sequenceId = sequenceId; this._context = context; + this._tokenMeter = tokenMeter ?? new TokenMeter(); this._contextShift = contextShift; this._gcRegistry = new FinalizationRegistry(this._context._reclaimUnusedSequenceId); @@ -600,6 +647,10 @@ export class LlamaContextSequence { return this._contextTokens.slice(); } + public get tokenMeter() { + return this._tokenMeter; + } + public get isLoadedToMemory() { return !this._disposed; } @@ -846,6 +897,7 @@ export class LlamaContextSequence { evalTokens, generateNewTokens, evaluationPriority, + this._tokenMeter, contextShiftOptions, (batchLogitIndex) => { const repeatPenaltyTokens = repeatPenalty?.punishTokens instanceof Function @@ -894,6 +946,7 @@ export class LlamaContextSequence { tokens: Token[], generateLogit: boolean, evaluationPriority: EvaluationPriority, + tokenMeter: TokenMeter, contextShiftOptions: Required, onDecodeDone: ((batchLogitIndex: BatchLogitIndex) => T | Promise) ): Promise { @@ -923,7 +976,8 @@ export class LlamaContextSequence { tokens: tokensToDecode, firstTokenSequenceIndex: this._nextTokenIndex, generateLogitAtTheEnd, - evaluationPriority + evaluationPriority, + tokenMeter }, !generateLogitAtTheEnd ? undefined : onDecodeDone @@ -990,7 +1044,7 @@ export class LlamaContextSequence { * @internal */ public static _create({ - sequenceId, context, + sequenceId, context, tokenMeter, contextShift: { size: contextShiftSize = Math.min(100, Math.ceil(context.contextSize / 2)), strategy: contextShiftStrategy = "eraseBeginning" @@ -998,11 +1052,13 @@ export class LlamaContextSequence { }: { sequenceId: number, context: LlamaContext, + tokenMeter?: TokenMeter, contextShift?: ContextShiftOptions }): LlamaContextSequence { return new LlamaContextSequence({ sequenceId, context, + tokenMeter, contextShift: { size: contextShiftSize, strategy: contextShiftStrategy @@ -1017,6 +1073,7 @@ type InternalQueuedDecode = { tokens: readonly Token[], generateLogitAtTheEnd: boolean, evaluationPriority: EvaluationPriority, + tokenMeter: TokenMeter, response: [accept: (res: any) => void, reject: (reason: unknown) => void], onDone?: (batchLogitIndex: BatchLogitIndex) => any }; @@ -1039,3 +1096,99 @@ function disposeContextSequenceIfReferenced(contextRef: WeakRef vramState.free) + throw new Error(`The context size of ${resolvedContextSize}${sequences > 1 ? ` with ${sequences} sequences` : ""} is too large for the available VRAM`); + + return resolvedContextSize; + } else if (contextSize === "auto" || typeof contextSize === "object") { + if (llamaGpu === false) + return modelTrainContextSize; + + const vramState = getVramState(); + + if (vramState.total === 0) + return modelTrainContextSize; + + const freeVram = vramState.free; + + const maxContextSize = contextSize === "auto" + ? getDefaultModelContextSize({trainContextSize: modelTrainContextSize}) + : Math.min( + contextSize.max ?? getDefaultModelContextSize({trainContextSize: modelTrainContextSize}), + getDefaultModelContextSize({trainContextSize: modelTrainContextSize}) + ); + + const minContextSize = contextSize === "auto" + ? minAllowedContextSizeInCalculations + : Math.max( + contextSize.min ?? minAllowedContextSizeInCalculations, + minAllowedContextSizeInCalculations + ); + + for (let testContextSize = maxContextSize; testContextSize >= minContextSize; testContextSize--) { + const contextVram = modelFileInsights.estimateContextResourceRequirements({ + contextSize: testContextSize, + batchSize: batchSize ?? getDefaultContextBatchSize({contextSize: testContextSize, sequences}), + modelGpuLayers: modelGpuLayers, + sequences, + isEmbeddingContext + }).gpuVram; + + if (contextVram <= freeVram) + return testContextSize; + } + + if (ignoreMemorySafetyChecks) + return minContextSize; + + throw new Error(`The available VRAM is too small to fit the context size of ${maxContextSize}${sequences > 1 ? ` with ${sequences} sequences` : ""}`); + } + + throw new Error(`Invalid context size: "${contextSize}"`); +} + +export function getDefaultContextBatchSize({contextSize, sequences}: {contextSize: number, sequences: number}) { + return Math.min(contextSize * sequences, 512); +} +export function getDefaultContextSequences() { + return 1; +} + +const defaultFallbackContextSize = 4096; +export function getDefaultModelContextSize({trainContextSize}: {trainContextSize?: number}) { + return trainContextSize ?? defaultFallbackContextSize; +} diff --git a/src/evaluator/LlamaContext/types.ts b/src/evaluator/LlamaContext/types.ts index 0c42d8e4..4e01516e 100644 --- a/src/evaluator/LlamaContext/types.ts +++ b/src/evaluator/LlamaContext/types.ts @@ -8,6 +8,9 @@ export type LlamaContextOptions = { * Each sequence is a different "text generation process" that can run in parallel to other sequences in the same context. * Although a single context has multiple sequences, the sequences are separate from each other and do not share data with each other. * This is beneficial for performance, as multiple sequences can be evaluated in parallel (on the same batch). + * + * Each sequence increases the memory usage of the context. + * Defaults to `1`. */ sequences?: number, @@ -15,10 +18,21 @@ export type LlamaContextOptions = { seed?: number | null, /** - * The number of tokens can the model see at once. - * Defaults to the context size the model was trained on. + * The number of tokens the model can see at once. + * - **`"auto"`** - adapt to the current VRAM state and attemp to set the context size as high as possible up to the size + * the model was trained on. + * - **`number`** - set the context size to a specific number of tokens. + * If there's not enough VRAM, an error will be thrown. + * Use with caution. + * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attemp to set the context size as high as possible + * up to the size the model was trained on, but at least `min` and at most `max`. + * + * Defaults to `"auto"`. */ - contextSize?: number, + contextSize?: "auto" | number | { + min?: number, + max?: number + }, /** * The number of tokens that can be processed at once by the GPU. @@ -28,7 +42,9 @@ export type LlamaContextOptions = { /** * number of threads to use to evaluate tokens. - * set to 0 to use the maximum threads supported by the current machine hardware + * set to 0 to use the maximum threads supported by the current machine hardware. + * + * Defaults to `6`. */ threads?: number, @@ -38,6 +54,14 @@ export type LlamaContextOptions = { /** An abort signal to abort the context creation */ createSignal?: AbortSignal, + /** + * Ignore insufficient memory errors and continue with the context creation. + * Can cause the process to crash if there's not enough VRAM for the new context. + * + * Defaults to `false`. + */ + ignoreMemorySafetyChecks?: boolean, + /** * embedding mode only * @internal @@ -77,11 +101,42 @@ export type LlamaContextSequenceRepeatPenalty = { }; export type BatchingOptions = { + /** + * The strategy used to dispatch items to be processed when there are items pending to be processed. + * - **`"nextTick"`** - dispatch the items on the next even loop tick. + * You can provide a custom function to define a custom dispatch schedule. + * + * Defaults to `"nextTick"`. + */ dispatchSchedule?: "nextTick" | CustomBatchingDispatchSchedule, - itemsPrioritizingStrategy?: "maximumParallelism" | "firstInFirstOut" | CustomBatchingPrioritizeStrategy + + /** + * The strategy used to prioritize pending items to be processed. + * - **`"maximumParallelism"`** - process as many different sequences in parallel as possible. + * - **`"firstInFirstOut"`** - process items in the order they were added. + * - **Custom prioritization function** - a custom function that prioritizes the items to be processed. + * See the `CustomBatchingPrioritizationStrategy` type for more information. + * + * Defaults to `"maximumParallelism"`. + */ + itemPrioritizationStrategy?: "maximumParallelism" | "firstInFirstOut" | CustomBatchingPrioritizationStrategy }; + +/** + * A function that schedules the dispatch of the batch items. + * Call the `dispatch` function to dispatch the items. + */ export type CustomBatchingDispatchSchedule = (dispatch: () => void) => void; -export type CustomBatchingPrioritizeStrategy = (options: { + +/** + * A function that prioritizes the batch items to be processed. + * The function receives an array of `items` and the `size` of how many tokens can be processed in this batch. + * + * The function should return an array of prioritized items, + * where the sum of `processAmount` of all the items is less or equal to the given `size` that the function received, + * and where the `item` of each prioritized item is the same reference to an original item in the `items` array. + */ +export type CustomBatchingPrioritizationStrategy = (options: { items: readonly BatchItem[], size: number }) => PrioritizedBatchItem[]; diff --git a/src/evaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/firstInFirstOutStrategy.ts b/src/evaluator/LlamaContext/utils/batchItemsPrioritizationStrategies/firstInFirstOutStrategy.ts similarity index 100% rename from src/evaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/firstInFirstOutStrategy.ts rename to src/evaluator/LlamaContext/utils/batchItemsPrioritizationStrategies/firstInFirstOutStrategy.ts diff --git a/src/evaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/maximumParallelismStrategy.ts b/src/evaluator/LlamaContext/utils/batchItemsPrioritizationStrategies/maximumParallelismStrategy.ts similarity index 100% rename from src/evaluator/LlamaContext/utils/batchItemsPrioritizingStrategies/maximumParallelismStrategy.ts rename to src/evaluator/LlamaContext/utils/batchItemsPrioritizationStrategies/maximumParallelismStrategy.ts diff --git a/src/evaluator/LlamaContext/utils/resolveBatchItemsPrioritizingStrategy.ts b/src/evaluator/LlamaContext/utils/resolveBatchItemsPrioritizationStrategy.ts similarity index 54% rename from src/evaluator/LlamaContext/utils/resolveBatchItemsPrioritizingStrategy.ts rename to src/evaluator/LlamaContext/utils/resolveBatchItemsPrioritizationStrategy.ts index bf7fe25b..fcdb5763 100644 --- a/src/evaluator/LlamaContext/utils/resolveBatchItemsPrioritizingStrategy.ts +++ b/src/evaluator/LlamaContext/utils/resolveBatchItemsPrioritizationStrategy.ts @@ -1,8 +1,8 @@ import {BatchingOptions} from "../types.js"; -import {maximumParallelismStrategy} from "./batchItemsPrioritizingStrategies/maximumParallelismStrategy.js"; -import {firstInFirstOutStrategy} from "./batchItemsPrioritizingStrategies/firstInFirstOutStrategy.js"; +import {maximumParallelismStrategy} from "./batchItemsPrioritizationStrategies/maximumParallelismStrategy.js"; +import {firstInFirstOutStrategy} from "./batchItemsPrioritizationStrategies/firstInFirstOutStrategy.js"; -export function resolveBatchItemsPrioritizingStrategy(strategy: Required["itemsPrioritizingStrategy"]) { +export function resolveBatchItemsPrioritizationStrategy(strategy: Required["itemPrioritizationStrategy"]) { if (strategy instanceof Function) return strategy; else if (strategy === "maximumParallelism") diff --git a/src/evaluator/LlamaEmbeddingContext.ts b/src/evaluator/LlamaEmbeddingContext.ts index 4cf4a77c..1b001ecc 100644 --- a/src/evaluator/LlamaEmbeddingContext.ts +++ b/src/evaluator/LlamaEmbeddingContext.ts @@ -6,8 +6,22 @@ import type {LlamaModel} from "./LlamaModel.js"; import type {LlamaContext, LlamaContextSequence} from "./LlamaContext/LlamaContext.js"; export type LlamaEmbeddingContextOptions = { - /** text context size */ - contextSize?: number, + /** + * The number of tokens the model can see at once. + * - **`"auto"`** - adapt to the current VRAM state and attemp to set the context size as high as possible up to the size + * the model was trained on. + * - **`number`** - set the context size to a specific number of tokens. + * If there's not enough VRAM, an error will be thrown. + * Use with caution. + * - **`{min?: number, max?: number}`** - adapt to the current VRAM state and attemp to set the context size as high as possible + * up to the size the model was trained on, but at least `min` and at most `max`. + * + * Defaults to `"auto"`. + */ + contextSize?: "auto" | number | { + min?: number, + max?: number + }, /** prompt processing batch size */ batchSize?: number, @@ -19,7 +33,15 @@ export type LlamaEmbeddingContextOptions = { threads?: number, /** An abort signal to abort the context creation */ - createSignal?: AbortSignal + createSignal?: AbortSignal, + + /** + * Ignore insufficient memory errors and continue with the context creation. + * Can cause the process to crash if there's not enough VRAM for the new context. + * + * Defaults to `false`. + */ + ignoreMemorySafetyChecks?: boolean }; export class LlamaEmbeddingContext { @@ -97,19 +119,18 @@ export class LlamaEmbeddingContext { }: { _model: LlamaModel }, { - contextSize = _model.trainContextSize, - batchSize = contextSize, + contextSize, + batchSize, threads = 6, - createSignal + createSignal, + ignoreMemorySafetyChecks }: LlamaEmbeddingContextOptions) { - const resolvedContextSize = Math.min(contextSize, _model.trainContextSize); - const resolvedBatchSize = Math.min(batchSize, resolvedContextSize); - const llamaContext = await _model.createContext({ - contextSize: resolvedContextSize, - batchSize: resolvedBatchSize, + contextSize, + batchSize, threads, createSignal, + ignoreMemorySafetyChecks, _embeddings: true, _noSeed: true }); diff --git a/src/evaluator/LlamaModel.ts b/src/evaluator/LlamaModel.ts index b6cd8597..dd93b116 100644 --- a/src/evaluator/LlamaModel.ts +++ b/src/evaluator/LlamaModel.ts @@ -5,20 +5,44 @@ import {removeNullFields} from "../utils/removeNullFields.js"; import {Token} from "../types.js"; import {ModelTypeDescription, AddonModel} from "../bindings/AddonTypes.js"; import {DisposalPreventionHandle, DisposeGuard} from "../utils/DisposeGuard.js"; -import {LlamaLocks} from "../bindings/types.js"; +import {BuildGpu, LlamaLocks} from "../bindings/types.js"; +import {GgufFileInfo} from "../gguf/types/GgufFileInfoTypes.js"; +import {readGgufFileInfo} from "../gguf/readGgufFileInfo.js"; +import {GgufInsights} from "../gguf/GgufInsights.js"; +import {findBestOption} from "../utils/findBestOption.js"; +import {InsufficientMemoryError} from "../utils/InsufficientMemoryError.js"; +import {minAllowedContextSizeInCalculations} from "../config.js"; import {LlamaContextOptions} from "./LlamaContext/types.js"; -import {LlamaContext} from "./LlamaContext/LlamaContext.js"; +import {getDefaultContextBatchSize, getDefaultModelContextSize, LlamaContext} from "./LlamaContext/LlamaContext.js"; import {LlamaEmbeddingContext, LlamaEmbeddingContextOptions} from "./LlamaEmbeddingContext.js"; import type {Llama} from "../bindings/Llama.js"; import type {BuiltinSpecialTokenValue} from "../utils/LlamaText.js"; +const fitContextExtraMemoryPaddingPercentage = 0.5; export type LlamaModelOptions = { /** path to the model on the filesystem */ modelPath: string, - /** number of layers to store in VRAM */ - gpuLayers?: number, + /** + * Number of layers to store in VRAM. + * - **`"auto"`** - adapt to the current VRAM state and try to fit as many layers as possible in it. + * Takes into account the VRAM required to create a context with a `contextSize` set to `"auto"`. + * - **`"max"`** - store all layers in VRAM. If there's not enough VRAM, an error will be thrown. Use with caution. + * - **`number`** - store the specified number of layers in VRAM. If there's not enough VRAM, an error will be thrown. Use with caution. + * - **`{min?: number, max?: number, fitContext?: {contextSize: number}}`** - adapt to the current VRAM state and try to fit as + * many layers as possible in it, but at least `min` and at most `max` layers. Set `fitContext` to the parameters of a context you + * intend to create with the model, so it'll take it into account in the calculations and leave enough memory for such a context. + * + * If GPU support is disabled, will be set to `0` automatically. + * + * Defaults to `"auto"`. + */ + gpuLayers?: "auto" | "max" | number | { + min?: number, + max?: number, + fitContext?: {contextSize: number} + }, /** only load the vocabulary, no weights */ vocabOnly?: boolean, @@ -42,7 +66,15 @@ export type LlamaModelOptions = { onLoadProgress?(loadProgress: number): void, /** An abort signal to abort the model load */ - loadSignal?: AbortSignal + loadSignal?: AbortSignal, + + /** + * Ignore insufficient memory errors and continue with the model load. + * Can cause the process to crash if there's not enough VRAM to fit the model. + * + * Defaults to `false`. + */ + ignoreMemorySafetyChecks?: boolean }; export class LlamaModel { @@ -50,6 +82,9 @@ export class LlamaModel { /** @internal */ public readonly _model: AddonModel; /** @internal */ public readonly _backendModelDisposeGuard: DisposeGuard; /** @internal */ private readonly _tokens: LlamaModelTokens; + /** @internal */ private readonly _fileInfo: GgufFileInfo; + /** @internal */ private readonly _fileInsights: GgufInsights; + /** @internal */ private readonly _gpuLayers: number; /** @internal */ private readonly _filename?: string; /** @internal */ private readonly _disposedState: DisposedState = {disposed: false}; /** @internal */ private readonly _disposeAggregator = new AsyncDisposeAggregator(); @@ -62,12 +97,21 @@ export class LlamaModel { private constructor({ modelPath, gpuLayers, vocabOnly, useMmap, useMlock, onLoadProgress, loadSignal - }: LlamaModelOptions, { - _llama + }: LlamaModelOptions & { + gpuLayers: number + }, { + _llama, + _fileInfo, + _fileInsights }: { - _llama: Llama + _llama: Llama, + _fileInfo: GgufFileInfo, + _fileInsights: GgufInsights }) { this._llama = _llama; + this._fileInfo = _fileInfo; + this._fileInsights = _fileInsights; + this._gpuLayers = gpuLayers; this._backendModelDisposeGuard = new DisposeGuard([this._llama._backendDisposeGuard]); this._llamaPreventDisposalHandle = this._llama._backendDisposeGuard.createPreventDisposalHandle(); this._model = new this._llama._bindings.AddonModel(path.resolve(process.cwd(), modelPath), removeNullFields({ @@ -139,6 +183,22 @@ export class LlamaModel { return this._filename; } + public get fileInfo(): GgufFileInfo { + return this._fileInfo; + } + + public get fileInsights(): GgufInsights { + return this._fileInsights; + } + + /** + * Number of layers offloaded to the GPU. + * If GPU support is disabled, this will always be `0`. + */ + public get gpuLayers(): number { + return this._gpuLayers; + } + /** * Total model size in memory in bytes */ @@ -253,8 +313,21 @@ export class LlamaModel { }: { _llama: Llama }) { - const model = new LlamaModel(modelOptions, {_llama}); const {loadSignal} = modelOptions; + const fileInfo = await readGgufFileInfo(modelOptions.modelPath, { + sourceType: "filesystem", + signal: loadSignal + }); + const ggufInsights = await GgufInsights.from(fileInfo, _llama); + const gpuLayers = resolveModelGpuLayersOption(modelOptions.gpuLayers, { + ggufInsights, + ignoreMemorySafetyChecks: modelOptions.ignoreMemorySafetyChecks, + getVramState: () => _llama._vramOrchestrator.getMemoryState(), + llamaVramPaddingSize: _llama._vramPadding.size, + llamaGpu: _llama.gpu, + llamaSupportsGpuOffloading: _llama.supportsGpuOffloading + }); + const model = new LlamaModel({...modelOptions, gpuLayers}, {_fileInfo: fileInfo, _fileInsights: ggufInsights, _llama}); function onAbort() { model._model.abortActiveModelLoad(); @@ -615,6 +688,261 @@ function disposeModelIfReferenced(modelRef: WeakRef) { void model.dispose(); } +export function resolveModelGpuLayersOption(gpuLayers: LlamaModelOptions["gpuLayers"], { + ggufInsights, isEmbeddingContext = false, ignoreMemorySafetyChecks = false, getVramState, llamaVramPaddingSize, + llamaGpu, llamaSupportsGpuOffloading +}: { + ggufInsights: GgufInsights, isEmbeddingContext?: boolean, ignoreMemorySafetyChecks?: boolean, + getVramState(): {total: number, free: number}, llamaVramPaddingSize: number, llamaGpu: BuildGpu, llamaSupportsGpuOffloading: boolean +}): number { + if (gpuLayers == null) + gpuLayers = "auto"; + + if (!llamaSupportsGpuOffloading) + return 0; + + if (gpuLayers === "max" || typeof gpuLayers === "number") { + const resolvedGpuLayers = typeof gpuLayers === "number" + ? Math.max(0, Math.min(ggufInsights.totalLayers, gpuLayers)) + : ggufInsights.totalLayers; + + if (ignoreMemorySafetyChecks) + return resolvedGpuLayers; + + const vramState = getVramState(); + const maxLayersRequirements = getVramRequiredForGpuLayers({ + gpuLayers: resolvedGpuLayers, + ggufInsights, + currentVram: vramState.free, + isEmbeddingContext + }); + + if (maxLayersRequirements == null) + throw new InsufficientMemoryError("Not enough VRAM to fit the model with the specified settings"); + + return resolvedGpuLayers; + } else if (gpuLayers === "auto" || typeof gpuLayers === "object") { + if (llamaGpu === false) + return 0; + + const vramState = getVramState(); + if (vramState.total === 0) + return 0; + + let freeVram = vramState.free; + if (typeof gpuLayers === "object" && gpuLayers.fitContext != null) { + freeVram -= llamaVramPaddingSize * fitContextExtraMemoryPaddingPercentage; + + if (freeVram < 0) + freeVram = 0; + } + + const bestGpuLayersOption = getBestGpuLayersForFreeVram({ + ggufInsights, + freeVram, + fitContext: typeof gpuLayers === "object" + ? gpuLayers.fitContext + : undefined, + minGpuLayers: typeof gpuLayers === "object" + ? gpuLayers.min + : undefined, + maxGpuLayers: typeof gpuLayers === "object" + ? gpuLayers.max + : undefined, + isEmbeddingContext + }); + + const hasGpuLayersRequirements = typeof gpuLayers === "object" && + (gpuLayers.min != null || gpuLayers.max != null || gpuLayers.fitContext?.contextSize != null); + + if (!ignoreMemorySafetyChecks && bestGpuLayersOption == null && hasGpuLayersRequirements) + throw new InsufficientMemoryError("Not enough VRAM to fit the model with the specified settings"); + + return bestGpuLayersOption ?? 0; + } + + throw new Error(`Invalid gpuLayers value: ${gpuLayers}`); +} + +function getBestGpuLayersForFreeVram({ + ggufInsights, + freeVram, + fitContext, + minGpuLayers, + maxGpuLayers, + isEmbeddingContext = false +}: { + ggufInsights: GgufInsights, + freeVram: number, + fitContext?: {contextSize: number}, + minGpuLayers?: number, + maxGpuLayers?: number, + isEmbeddingContext?: boolean +}) { + return findBestOption({ + *generator() { + const minLayers = Math.floor(Math.max(0, minGpuLayers ?? 0)); + const maxLayers = Math.floor(Math.min(ggufInsights.totalLayers, maxGpuLayers ?? ggufInsights.totalLayers)); + + for (let layers = maxLayers; layers >= minLayers; layers--) { + yield { + gpuLayers: layers + }; + } + }, + score(option) { + const layersRequirements = getVramRequiredForGpuLayers({ + gpuLayers: option.gpuLayers, + ggufInsights, + currentVram: freeVram, + fitContext, + isEmbeddingContext + }); + + if (layersRequirements == null) + return null; + + return scoreGpuLayersAndContextCombination({gpuLayers: option.gpuLayers, contextSize: layersRequirements.contextSize}, { + totalGpuLayers: ggufInsights.totalLayers, + trainContextSize: getDefaultModelContextSize({trainContextSize: ggufInsights.trainContextSize}) + }); + } + })?.gpuLayers ?? null; +} + +function scoreGpuLayersAndContextCombination({gpuLayers, contextSize}: {gpuLayers: number, contextSize: number}, { + totalGpuLayers, trainContextSize +}: { + totalGpuLayers: number, trainContextSize: number +}) { + function scoreGpuLayers() { + return scoreLevels(gpuLayers, [{ + start: 0, + points: 4 + }, { + start: 1, + points: 26 + }, { + start: totalGpuLayers, + points: 14, + end: totalGpuLayers + }]); + } + + function scoreContextSize() { + const gpuLayersPercentage = gpuLayers / totalGpuLayers; + + return scoreLevels(contextSize, [{ + start: 0, + points: 2 + }, { + start: 1024, + points: 4 + }, { + start: 2048, + points: gpuLayersPercentage < 0.1 ? 1 : 8 + }, { + start: 4096, + points: gpuLayersPercentage < 0.3 ? 4 : 16 + }, { + start: 8192, + points: gpuLayersPercentage < 0.6 ? 1 : 8, + end: Math.max(trainContextSize, 16384) + }]); + } + + return scoreGpuLayers() + scoreContextSize(); +} + +function scoreLevels(num: number, levels: {start: number, end?: number, points: number}[]) { + let res = 0; + + for (let i = 0; i < levels.length; i++) { + const level = levels[i]; + const start = level.start; + const end = level.end ?? levels[i + 1]?.start ?? Math.max(start, num); + + if (num < start) + break; + else if (num >= end) + res += level.points; + else + res += level.points * ((num - start) / (end - start)); + } + + return res; +} + +function getVramRequiredForGpuLayers({ + gpuLayers, ggufInsights, currentVram, fitContext, isEmbeddingContext +}: { + gpuLayers: number, ggufInsights: GgufInsights, currentVram: number, fitContext?: {contextSize: number}, isEmbeddingContext: boolean +}) { + const modelVram = ggufInsights.estimateModelResourceRequirements({gpuLayers}).gpuVram; + + if (modelVram > currentVram) + return null; + + if (fitContext != null) { + const contextVram = ggufInsights.estimateContextResourceRequirements({ + contextSize: fitContext.contextSize, + batchSize: getDefaultContextBatchSize({contextSize: fitContext.contextSize, sequences: 1}), + modelGpuLayers: gpuLayers, + sequences: 1, + isEmbeddingContext + }).gpuVram; + + const totalVram = modelVram + contextVram; + if (totalVram > currentVram) + return null; + + return { + contextSize: fitContext.contextSize, + contextVram, + totalVram + }; + } + + const maxContext = findMaxPossibleContextSizeForVram({ + gpuLayers, + ggufInsights, + vram: currentVram - modelVram, + isEmbeddingContext + }); + + if (maxContext == null || modelVram + maxContext.vram > currentVram) + return null; + + return { + contextSize: maxContext.contextSize, + contextVram: maxContext.vram, + totalVram: modelVram + maxContext.vram + }; +} + +function findMaxPossibleContextSizeForVram({gpuLayers, ggufInsights, vram, isEmbeddingContext}: { + gpuLayers: number, ggufInsights: GgufInsights, vram: number, isEmbeddingContext: boolean +}) { + const maxContextSize = getDefaultModelContextSize({trainContextSize: ggufInsights.trainContextSize}); + + for (let contextSize = maxContextSize; contextSize >= minAllowedContextSizeInCalculations; contextSize--) { + const contextVram = ggufInsights.estimateContextResourceRequirements({ + contextSize, + batchSize: getDefaultContextBatchSize({contextSize, sequences: 1}), + modelGpuLayers: gpuLayers, + sequences: 1, + isEmbeddingContext + }).gpuVram; + + if (contextVram <= vram) + return { + contextSize, + vram: contextVram + }; + } + + return null; +} type DisposedState = { disposed: boolean diff --git a/src/evaluator/TokenMeter.ts b/src/evaluator/TokenMeter.ts new file mode 100644 index 00000000..0fae669f --- /dev/null +++ b/src/evaluator/TokenMeter.ts @@ -0,0 +1,108 @@ +/** + * Tracks the evaluation usage of tokens. + */ +export class TokenMeter { + private _inputTokens: number = 0; + private _outputTokens: number = 0; + private _restoreStateTokens: number = 0; + + /** + * The number of input tokens used + */ + public get usedInputTokens() { + return this._inputTokens; + } + + /** + * The number of tokens generated by a model + */ + public get usedOutputTokens() { + return this._outputTokens; + } + + /** + * The number of tokens used as input to restore a context sequence state to continue previous evaluation. + * This may be consumed by virtual context sequences. + */ + public get usedRestoreStateTokens() { + return this._restoreStateTokens; + } + + /** + * Get the current state of the token meter + */ + public getState(): TokenMeterState { + return { + usedInputTokens: this.usedInputTokens, + usedOutputTokens: this.usedOutputTokens, + usedRestoreStateTokens: this.usedRestoreStateTokens + }; + } + + /** + * Log the usage of tokens + */ + public useTokens(tokens: number, type: "input" | "output" | "restoreState") { + if (tokens < 0) + throw new RangeError("Tokens cannot be negative"); + else if (tokens === 0) + return; + + if (type === "input") + this._inputTokens += tokens; + else if (type === "output") + this._outputTokens += tokens; + else if (type === "restoreState") + this._restoreStateTokens += tokens; + else { + void (type satisfies never); + throw new TypeError(`Unknown token type: ${type}`); + } + } + + /** + * Get the difference between the current meter and another meter + */ + public diff(meter: TokenMeter | TokenMeterState) { + return TokenMeter.diff(this, meter); + } + + /** + * Log the usage of tokens on multiple meters + */ + public static useTokens( + meters: null | undefined | TokenMeter | readonly TokenMeter[] | ReadonlySet, + tokens: number, + type: "input" | "output" | "restoreState" + ) { + if (meters == null) + return; + + if (meters instanceof TokenMeter) + meters.useTokens(tokens, type); + else { + for (const meter of meters) + meter.useTokens(tokens, type); + } + } + + /** + * Get the difference between two meters + */ + public static diff( + meter1: TokenMeter | TokenMeterState, + meter2: TokenMeter | TokenMeterState + ) { + return { + usedInputTokens: meter1.usedInputTokens - meter2.usedInputTokens, + usedOutputTokens: meter1.usedOutputTokens - meter2.usedOutputTokens, + usedRestoreStateTokens: meter1.usedRestoreStateTokens - meter2.usedRestoreStateTokens + }; + } +} + +export type TokenMeterState = { + usedInputTokens: number, + usedOutputTokens: number, + usedRestoreStateTokens: number +}; diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts index 05fb55cc..17d74dbc 100644 --- a/src/gguf/GgufInsights.ts +++ b/src/gguf/GgufInsights.ts @@ -22,6 +22,16 @@ export class GgufInsights { this._modelSize = calculateTensorsSize(ggufFileInfo.tensorInfo ?? [], llama); } + /** The context size the model was trained on */ + public get trainContextSize() { + return this.ggufFileInfo.architectureMetadata.context_length; + } + + /** The size of an embedding vector the model can produce */ + public get embeddingVectorSize() { + return this.ggufFileInfo.architectureMetadata.embedding_length; + } + public get totalLayers() { if (this._totalLayers != null) return this._totalLayers; @@ -56,6 +66,8 @@ export class GgufInsights { contextSize: number, batchSize: number, modelGpuLayers: number, sequences: number, isEmbeddingContext?: boolean, includeGraphOverhead?: boolean }): GgufInsightsResourceRequirements { + const actualContextSize = contextSize * sequences; + const totalLayers = this.totalLayers; const finalGpuLayers = Math.max(0, Math.min(modelGpuLayers ?? totalLayers, totalLayers)); const finalCpuLayers = totalLayers - finalGpuLayers; @@ -86,7 +98,7 @@ export class GgufInsights { const sKvCell = this._llama._consts.llamaPosSize + sizeTBytes + this._llama._consts.llamaSeqIdSize; const kvSelfLength = this.ggufFileInfo.metadata.general.architecture === GgufArchitectureType.mamba ? Math.max(1, sequences) - : contextSize; + : actualContextSize; const sKvCells = kvSelfLength * sKvCell; const overheadMemory = ( @@ -120,7 +132,7 @@ export class GgufInsights { }, 0); // magic numbers for estimation. will be improved in the future - return totalDimensions * 77.655 * (contextSize / 4096); + return totalDimensions * 77.655 * (actualContextSize / 4096); }; const graphOverheadMemory = !includeGraphOverhead @@ -134,13 +146,13 @@ export class GgufInsights { ? (overheadMemory + graphOverheadMemory) : 0 ) + - this._estimateKvMemorySizeInBytes(contextSize, finalCpuLayers); + this._estimateKvMemorySizeInBytes(actualContextSize, finalCpuLayers); const gpuVram = usingGpu ? ( overheadMemory + graphOverheadMemory + this._estimateKvMemorySizeInBytes( - contextSize, + actualContextSize, finalGpuLayers < totalLayers ? (finalGpuLayers + 1) : finalGpuLayers diff --git a/src/gguf/consts.ts b/src/gguf/consts.ts index 9140b0b3..48f0beb4 100644 --- a/src/gguf/consts.ts +++ b/src/gguf/consts.ts @@ -1,6 +1,6 @@ import retry from "async-retry"; -export const ggufDefaultRetryOptions: retry.Options = { +export const ggufDefaultFetchRetryOptions: retry.Options = { retries: 10, factor: 2, minTimeout: 1000, diff --git a/src/gguf/fileReaders/GgufFsFileReader.ts b/src/gguf/fileReaders/GgufFsFileReader.ts index 5f5113a1..83f58f26 100644 --- a/src/gguf/fileReaders/GgufFsFileReader.ts +++ b/src/gguf/fileReaders/GgufFsFileReader.ts @@ -1,23 +1,22 @@ import fs from "node:fs/promises"; -import retry from "async-retry"; import {withLock} from "lifecycle-utils"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {defaultExtraAllocationSize, ggufDefaultRetryOptions} from "../consts.js"; +import {defaultExtraAllocationSize} from "../consts.js"; import {GgufFileReader} from "./GgufFileReader.js"; type GgufFsFileReaderOptions = { filePath: string, - retryOptions?: retry.Options + signal?: AbortSignal }; export class GgufFsFileReader extends GgufFileReader { public readonly filePath: string; - public readonly retryOptions: retry.Options; + private readonly _signal?: AbortSignal; - public constructor({filePath, retryOptions = ggufDefaultRetryOptions}: GgufFsFileReaderOptions) { + public constructor({filePath, signal}: GgufFsFileReaderOptions) { super(); this.filePath = filePath; - this.retryOptions = retryOptions; + this._signal = signal; } public async readByteRange(offset: number | GgufReadOffset, length: number) { @@ -33,13 +32,14 @@ export class GgufFsFileReader extends GgufFileReader { } private async _readToExpandBufferUpToOffset(endOffset: number, extraAllocationSize: number = defaultExtraAllocationSize) { - return await withLock(this, "modifyBuffer", async () => { + return await withLock(this, "modifyBuffer", this._signal, async () => { if (endOffset < this._buffer.length) return; - const missingBytesBuffer = await retry(async () => { - return await this._readByteRange(this._buffer.length, endOffset + extraAllocationSize - this._buffer.length); - }, this.retryOptions); + const missingBytesBuffer = await this._readByteRange( + this._buffer.length, + endOffset + extraAllocationSize - this._buffer.length + ); this._addToBuffer(missingBytesBuffer); }); @@ -48,6 +48,9 @@ export class GgufFsFileReader extends GgufFileReader { private async _readByteRange(start: number, length: number) { const fd = await fs.open(this.filePath, "r"); try { + if (this._signal?.aborted) + throw this._signal.reason; + const buffer = Buffer.alloc(length); await fd.read(buffer, 0, length, start); return buffer; diff --git a/src/gguf/fileReaders/GgufNetworkFetchFileReader.ts b/src/gguf/fileReaders/GgufNetworkFetchFileReader.ts index 0a2abee6..6f07391f 100644 --- a/src/gguf/fileReaders/GgufNetworkFetchFileReader.ts +++ b/src/gguf/fileReaders/GgufNetworkFetchFileReader.ts @@ -1,25 +1,28 @@ import retry from "async-retry"; import {withLock} from "lifecycle-utils"; import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {defaultExtraAllocationSize, ggufDefaultRetryOptions} from "../consts.js"; +import {defaultExtraAllocationSize, ggufDefaultFetchRetryOptions} from "../consts.js"; import {GgufFileReader} from "./GgufFileReader.js"; type GgufFetchFileReaderOptions = { url: string, retryOptions?: retry.Options, - headers?: Record + headers?: Record, + signal?: AbortSignal }; export class GgufNetworkFetchFileReader extends GgufFileReader { public readonly url: string; public readonly retryOptions: retry.Options; public readonly headers: Record; + private readonly _signal?: AbortSignal; - public constructor({url, retryOptions = ggufDefaultRetryOptions, headers}: GgufFetchFileReaderOptions) { + public constructor({url, retryOptions = ggufDefaultFetchRetryOptions, headers, signal}: GgufFetchFileReaderOptions) { super(); this.url = url; this.retryOptions = retryOptions; this.headers = headers ?? {}; + this._signal = signal; } public async readByteRange(offset: number | GgufReadOffset, length: number) { @@ -35,13 +38,26 @@ export class GgufNetworkFetchFileReader extends GgufFileReader { } private async _fetchToExpandBufferUpToOffset(endOffset: number, extraAllocationSize: number = defaultExtraAllocationSize) { - await withLock(this, "modifyBuffer", async () => { + await withLock(this, "modifyBuffer", this._signal, async () => { if (endOffset < this._buffer.length) return; - const missingBytesBuffer = await retry(async () => { - return await this._fetchByteRange(this._buffer.length, endOffset + extraAllocationSize - this._buffer.length); + const missingBytesBuffer = await retry(async (bail) => { + try { + return await this._fetchByteRange(this._buffer.length, endOffset + extraAllocationSize - this._buffer.length); + } catch (err) { + if (this._signal?.aborted) { + bail(this._signal.reason); + throw this._signal.reason; + } + + throw err; + } }, this.retryOptions); + + if (this._signal?.aborted) + throw this._signal.reason; + this._addToBuffer(missingBytesBuffer); }); } @@ -52,7 +68,8 @@ export class GgufNetworkFetchFileReader extends GgufFileReader { ...this.headers, Range: `bytes=${start}-${start + length}`, accept: "*/*" - } + }, + signal: this._signal }); if (!response.ok) diff --git a/src/gguf/readGgufFileInfo.ts b/src/gguf/readGgufFileInfo.ts index e39d93eb..135620e6 100644 --- a/src/gguf/readGgufFileInfo.ts +++ b/src/gguf/readGgufFileInfo.ts @@ -2,7 +2,7 @@ import retry from "async-retry"; import {parseGguf} from "./parser/parseGguf.js"; import {GgufNetworkFetchFileReader} from "./fileReaders/GgufNetworkFetchFileReader.js"; import {GgufFsFileReader} from "./fileReaders/GgufFsFileReader.js"; -import {ggufDefaultRetryOptions} from "./consts.js"; +import {ggufDefaultFetchRetryOptions} from "./consts.js"; /** @@ -12,26 +12,53 @@ import {ggufDefaultRetryOptions} from "./consts.js"; export async function readGgufFileInfo(pathOrUrl: string, { readTensorInfo = true, sourceType, - retryOptions = ggufDefaultRetryOptions, ignoreKeys = [], - logWarnings = true + logWarnings = true, + fetchRetryOptions = ggufDefaultFetchRetryOptions, + fetchHeaders = {}, + signal }: { + /** + * Whether to read the tensor info from the file's header + * Enabled by default. + */ readTensorInfo?: boolean, + + /** + * Set to a specific value to force it to only use that source type. + * By default, it detects whether the path is a network URL or a filesystem path and uses the appropriate reader accordingly. + */ sourceType?: "network" | "filesystem", - retryOptions?: retry.Options, + + /** + * Metadata keys to ignore when parsing the metadata. + * For example, `["tokenizer.ggml.tokens"]` + */ ignoreKeys?: string[], - logWarnings?: boolean + + /** Whether to log warnings */ + logWarnings?: boolean, + + /** Relevant only when fetching from a network */ + fetchRetryOptions?: retry.Options, + + /** Relevant only when fetching from a network */ + fetchHeaders?: Record, + + signal?: AbortSignal } = {}) { function createFileReader() { if (sourceType === "network" || (sourceType == null && (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")))) { return new GgufNetworkFetchFileReader({ url: pathOrUrl, - retryOptions: retryOptions + retryOptions: fetchRetryOptions, + headers: fetchHeaders, + signal }); } else if (sourceType === "filesystem" || sourceType == null) { return new GgufFsFileReader({ filePath: pathOrUrl, - retryOptions: retryOptions + signal }); } diff --git a/src/gguf/types/GgufFileInfoTypes.ts b/src/gguf/types/GgufFileInfoTypes.ts index d48efb6f..2cbd22bb 100644 --- a/src/gguf/types/GgufFileInfoTypes.ts +++ b/src/gguf/types/GgufFileInfoTypes.ts @@ -1,8 +1,8 @@ -import {GgufReadOffset} from "../utils/GgufReadOffset.js"; -import {GgufFileReader} from "../fileReaders/GgufFileReader.js"; -import {MergeOptionalUnionTypes} from "../../utils/mergeUnionTypes.js"; -import {GgufArchitectureType, GgufMetadata} from "./GgufMetadataTypes.js"; -import {GgufTensorInfo} from "./GgufTensorInfoTypes.js"; +import type {GgufReadOffset} from "../utils/GgufReadOffset.js"; +import type {GgufFileReader} from "../fileReaders/GgufFileReader.js"; +import type {MergeOptionalUnionTypes} from "../../utils/mergeUnionTypes.js"; +import type {GgufArchitectureType, GgufMetadata} from "./GgufMetadataTypes.js"; +import type {GgufTensorInfo} from "./GgufTensorInfoTypes.js"; export type MetadataValue = string | number | bigint | boolean | MetadataValue[]; export type MetadataKeyValueRecord = Record; @@ -11,19 +11,19 @@ export type MetadataNestedObject = { }; export type GgufFileInfo = { - version: 2 | 3 | number, - tensorCount: number | bigint, - metadata: GgufMetadata, - metadataSize: number, + readonly version: 2 | 3 | number, + readonly tensorCount: number | bigint, + readonly metadata: GgufMetadata, + readonly metadataSize: number, /** Same value as `metadata[metadata.general.architecture]`, but with merged types for convenience */ - architectureMetadata: MergeOptionalUnionTypes>, + readonly architectureMetadata: MergeOptionalUnionTypes>, /** can be null if `readTensorInfo` is set to `false` */ - tensorInfo?: GgufTensorInfo[], + readonly tensorInfo?: GgufTensorInfo[], /** can be null if `readTensorInfo` is set to `false` */ - tensorInfoSize?: number + readonly tensorInfoSize?: number }; diff --git a/src/gguf/types/GgufMetadataTypes.ts b/src/gguf/types/GgufMetadataTypes.ts index 352a684f..923844ba 100644 --- a/src/gguf/types/GgufMetadataTypes.ts +++ b/src/gguf/types/GgufMetadataTypes.ts @@ -29,16 +29,16 @@ export const enum GgufArchitectureType { } export type GgufMetadata = { - general: GgufMetadataGeneral, - tokenizer: GgufMetadataTokenizer + readonly general: GgufMetadataGeneral, + readonly tokenizer: GgufMetadataTokenizer } & ( GgufArchitectureType extends A ? { - [key in GgufArchitectureType]?: key extends keyof GgufMetadataLlmToType + readonly [key in GgufArchitectureType]?: key extends keyof GgufMetadataLlmToType ? GgufMetadataLlmToType[key] : GgufMetadataLlmDefaultArchitectureType } : { - [key in A]: key extends keyof GgufMetadataLlmToType + readonly [key in A]: key extends keyof GgufMetadataLlmToType ? GgufMetadataLlmToType[key] : GgufMetadataLlmDefaultArchitectureType } @@ -94,7 +94,7 @@ export enum GgufFileType { export type GgufMetadataGeneral = { - architecture: A, + readonly architecture: A, /** * The version of the quantization format. Not required if the model is not @@ -104,7 +104,7 @@ export type GgufMetadataGeneral( diff --git a/src/gguf/types/GgufTensorInfoTypes.ts b/src/gguf/types/GgufTensorInfoTypes.ts index 082d0260..d144935e 100644 --- a/src/gguf/types/GgufTensorInfoTypes.ts +++ b/src/gguf/types/GgufTensorInfoTypes.ts @@ -1,8 +1,8 @@ export type GgufTensorInfo = { - name: string, - dimensions: (number | bigint)[], - ggmlType: GgmlType, - offset: number | bigint + readonly name: string, + readonly dimensions: readonly (number | bigint)[], + readonly ggmlType: GgmlType, + readonly offset: number | bigint }; export const enum GgmlType { diff --git a/src/index.ts b/src/index.ts index 4f0503fd..728dc395 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,7 +2,7 @@ import {DisposedError} from "lifecycle-utils"; import {Llama} from "./bindings/Llama.js"; import {getLlama, LlamaOptions} from "./bindings/getLlama.js"; import {NoBinaryFoundError} from "./bindings/utils/NoBinaryFoundError.js"; -import {LlamaLogLevel} from "./bindings/types.js"; +import {LlamaLogLevel, LlamaLogLevelGreaterThan, LlamaLogLevelGreaterThanOrEqual} from "./bindings/types.js"; import {LlamaModel, LlamaModelInfillTokens, type LlamaModelOptions, LlamaModelTokens} from "./evaluator/LlamaModel.js"; import {LlamaGrammar, type LlamaGrammarOptions} from "./evaluator/LlamaGrammar.js"; import {LlamaJsonSchemaGrammar} from "./evaluator/LlamaJsonSchemaGrammar.js"; @@ -14,7 +14,7 @@ import { } from "./evaluator/LlamaEmbeddingContext.js"; import { type LlamaContextOptions, type BatchingOptions, type LlamaContextSequenceRepeatPenalty, type CustomBatchingDispatchSchedule, - type CustomBatchingPrioritizeStrategy, type BatchItem, type PrioritizedBatchItem, type ContextShiftOptions, + type CustomBatchingPrioritizationStrategy, type BatchItem, type PrioritizedBatchItem, type ContextShiftOptions, type ContextTokensDeleteRange, type EvaluationPriority } from "./evaluator/LlamaContext/types.js"; import { @@ -29,7 +29,9 @@ import { import { LlamaCompletion, type LlamaCompletionOptions, type LlamaCompletionGenerationOptions, type LlamaInfillGenerationOptions } from "./evaluator/LlamaCompletion.js"; +import {TokenMeter, type TokenMeterState} from "./evaluator/TokenMeter.js"; import {UnsupportedError} from "./utils/UnsupportedError.js"; +import {InsufficientMemoryError} from "./utils/InsufficientMemoryError.js"; import {ChatWrapper, type ChatWrapperSettings} from "./ChatWrapper.js"; import {EmptyChatWrapper} from "./chatWrappers/EmptyChatWrapper.js"; import {LlamaChatWrapper} from "./chatWrappers/LlamaChatWrapper.js"; @@ -83,7 +85,7 @@ export { type LlamaContextOptions, type BatchingOptions, type CustomBatchingDispatchSchedule, - type CustomBatchingPrioritizeStrategy, + type CustomBatchingPrioritizationStrategy, type BatchItem, type PrioritizedBatchItem, type ContextShiftOptions, @@ -111,7 +113,10 @@ export { type LlamaCompletionOptions, type LlamaCompletionGenerationOptions, type LlamaInfillGenerationOptions, + TokenMeter, + type TokenMeterState, UnsupportedError, + InsufficientMemoryError, DisposedError, ChatWrapper, type ChatWrapperSettings, @@ -157,5 +162,7 @@ export { type GbnfJsonEnumSchema, type GbnfJsonOneOfSchema, type GbnfJsonObjectSchema, - type GbnfJsonArraySchema + type GbnfJsonArraySchema, + LlamaLogLevelGreaterThan, + LlamaLogLevelGreaterThanOrEqual }; diff --git a/src/utils/InsufficientMemoryError.ts b/src/utils/InsufficientMemoryError.ts new file mode 100644 index 00000000..78674f29 --- /dev/null +++ b/src/utils/InsufficientMemoryError.ts @@ -0,0 +1,5 @@ +export class InsufficientMemoryError extends Error { + public constructor(message: string = "Insufficient memory") { + super(message); + } +} diff --git a/src/utils/findBestOption.ts b/src/utils/findBestOption.ts new file mode 100644 index 00000000..a14b2f16 --- /dev/null +++ b/src/utils/findBestOption.ts @@ -0,0 +1,21 @@ +export function findBestOption({generator, score}: { + generator: () => Generator, + score: (option: O) => number | null +}) { + let bestOption: O | null = null; + let bestScore: number | null = null; + + for (const option of generator()) { + const currentScore = score(option); + + if (currentScore === Infinity) + return option; + + if (currentScore != null && (bestScore == null || currentScore > bestScore)) { + bestOption = option; + bestScore = currentScore; + } + } + + return bestOption; +} diff --git a/src/utils/parseModelTypeDescription.ts b/src/utils/parseModelTypeDescription.ts deleted file mode 100644 index c84d894c..00000000 --- a/src/utils/parseModelTypeDescription.ts +++ /dev/null @@ -1,11 +0,0 @@ -import type {AddonModelArchName, AddonModelFileTypeName, AddonModelTypeName, ModelTypeDescription} from "../bindings/AddonTypes.js"; - -export function parseModelTypeDescription(modelTypeDescription: ModelTypeDescription) { - const [arch, type, ...fileTypeParts] = modelTypeDescription.split(" "); - - return { - arch: arch as AddonModelArchName, - type: type as AddonModelTypeName, - fileType: fileTypeParts.join(" ") as AddonModelFileTypeName - }; -} diff --git a/src/utils/resolveChatWrapper.ts b/src/utils/resolveChatWrapper.ts index 122fe2f1..ef3a3693 100644 --- a/src/utils/resolveChatWrapper.ts +++ b/src/utils/resolveChatWrapper.ts @@ -8,7 +8,7 @@ export function resolveChatWrapper(chatWrapper: "auto" | ChatWrapper, model: Lla const chatWrapper = resolveChatWrapperBasedOnModel({ bosString: model.tokens.bosString, filename: model.filename, - typeDescription: model.typeDescription + fileInfo: model.fileInfo }); if (chatWrapper != null) diff --git a/test/modelDependent/functionary/chatSession.test.ts b/test/modelDependent/functionary/chatSession.test.ts index 28b8d306..228ee1ba 100644 --- a/test/modelDependent/functionary/chatSession.test.ts +++ b/test/modelDependent/functionary/chatSession.test.ts @@ -35,5 +35,102 @@ describe("functionary", () => { expect(res2).to.eql("6+6 equals 12."); }); + + test("disposing a context sequences removes the current state", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + const context = await model.createContext({ + contextSize: 4096 + }); + const contextSequence = context.getSequence(); + const chatSession = new LlamaChatSession({ + contextSequence, + autoDisposeSequence: false + }); + + const res = await chatSession.prompt("How much is 6+6"); + + expect(res).to.eql("6+6 equals 12."); + const tokenMeterState = contextSequence.tokenMeter.getState(); + expect(tokenMeterState).to.toMatchInlineSnapshot(` + { + "usedInputTokens": 140, + "usedOutputTokens": 17, + "usedRestoreStateTokens": 0, + } + `); + + chatSession.dispose(); + contextSequence.dispose(); + + const contextSequence2 = context.getSequence(); + const chatSession2 = new LlamaChatSession({ + contextSequence: contextSequence2 + }); + + const res2 = await chatSession2.prompt("How much is 6+6+6"); + + const tokenMeterState2 = contextSequence2.tokenMeter.getState(); + expect(tokenMeterState2).to.toMatchInlineSnapshot(` + { + "usedInputTokens": 142, + "usedOutputTokens": 19, + "usedRestoreStateTokens": 0, + } + `); + expect(tokenMeterState2.usedInputTokens).to.be.greaterThanOrEqual(tokenMeterState.usedInputTokens); + expect(res2).to.eql("6+6+6 equals 18."); + }); + + test("reusing a context sequences utilizes existing state", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + const context = await model.createContext({ + contextSize: 4096 + }); + const contextSequence = context.getSequence(); + const chatSession = new LlamaChatSession({ + contextSequence, + autoDisposeSequence: false + }); + + const res = await chatSession.prompt("How much is 6+6"); + + expect(res).to.eql("6+6 equals 12."); + const tokenMeterState = contextSequence.tokenMeter.getState(); + expect(tokenMeterState).to.toMatchInlineSnapshot(` + { + "usedInputTokens": 140, + "usedOutputTokens": 17, + "usedRestoreStateTokens": 0, + } + `); + + chatSession.dispose(); + const chatSession2 = new LlamaChatSession({ + contextSequence + }); + + const res2 = await chatSession2.prompt("How much is 6+6+6"); + + const tokenMeterStateDiff = contextSequence.tokenMeter.diff(tokenMeterState); + expect(tokenMeterStateDiff).to.toMatchInlineSnapshot(` + { + "usedInputTokens": 25, + "usedOutputTokens": 19, + "usedRestoreStateTokens": 0, + } + `); + expect(tokenMeterStateDiff.usedInputTokens).to.be.lessThan(tokenMeterState.usedInputTokens); + expect(res2).to.eql("6+6+6 equals 18."); + }); }); }); diff --git a/test/modelDependent/functionary/functionaryModelGpuLayersOptions.test.ts b/test/modelDependent/functionary/functionaryModelGpuLayersOptions.test.ts new file mode 100644 index 00000000..b7af73f0 --- /dev/null +++ b/test/modelDependent/functionary/functionaryModelGpuLayersOptions.test.ts @@ -0,0 +1,627 @@ +import {describe, expect, it} from "vitest"; +import {getModelFile} from "../../utils/modelFiles.js"; +import {getTestLlama} from "../../utils/getTestLlama.js"; +import {LlamaModelOptions, resolveModelGpuLayersOption} from "../../../src/evaluator/LlamaModel.js"; +import {readGgufFileInfo} from "../../../src/gguf/readGgufFileInfo.js"; +import {GgufInsights} from "../../../src/gguf/GgufInsights.js"; +import {defaultLlamaVramPadding} from "../../../src/bindings/getLlama.js"; +import {BuildGpu} from "../../../src/bindings/types.js"; +import {resolveContextContextSizeOption} from "../../../src/evaluator/LlamaContext/LlamaContext.js"; + +describe("functionary", () => { + describe("model options", () => { + describe("Resolve the correct number of GPU layers", async () => { + const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); + const llama = await getTestLlama(); + + const fileInfo = await readGgufFileInfo(modelPath); + const ggufInsights = await GgufInsights.from(fileInfo, llama); + + const s1GB = Math.pow(1024, 3); + + function resolveGpuLayers(gpuLayers: LlamaModelOptions["gpuLayers"], { + totalVram, freeVram, ignoreMemorySafetyChecks = false, llamaGpu = "metal" + }: { + totalVram: number, freeVram: number, ignoreMemorySafetyChecks?: boolean, llamaGpu?: BuildGpu + }) { + const resolvedGpuLayers = resolveModelGpuLayersOption(gpuLayers, { + ggufInsights, + ignoreMemorySafetyChecks, + getVramState: () => ({ + total: llamaGpu === false ? 0 : totalVram, + free: llamaGpu === false ? 0 : freeVram + }), + llamaVramPaddingSize: defaultLlamaVramPadding(llamaGpu === false ? 0 : totalVram), + llamaGpu, + llamaSupportsGpuOffloading: llamaGpu !== false + }); + + function resolveAutoContextSize() { + const modelVram = ggufInsights.estimateModelResourceRequirements({ + gpuLayers: resolvedGpuLayers + }).gpuVram; + + try { + return resolveContextContextSizeOption({ + contextSize: "auto", + batchSize: undefined, + sequences: 1, + modelFileInsights: ggufInsights, + modelGpuLayers: resolvedGpuLayers, + modelTrainContextSize: ggufInsights.trainContextSize ?? 4096, + getVramState: () => ({ + total: llamaGpu === false ? 0 : totalVram, + free: llamaGpu === false ? 0 : (freeVram - modelVram) + }), + llamaGpu, + ignoreMemorySafetyChecks: false, + isEmbeddingContext: false + }); + } catch (err) { + return null; + } + } + + return { + gpuLayers: resolvedGpuLayers, + contextSize: resolveAutoContextSize() + }; + } + + it("attempts to resolve 0 gpuLayers", () => { + { + const res = resolveGpuLayers(0, { + totalVram: s1GB * 6, + freeVram: s1GB * 1 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + { + const res = resolveGpuLayers(0, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + + { + const res = resolveGpuLayers(0, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + }); + + it("attempts to resolve 16 gpuLayers", () => { + { + const res = resolveGpuLayers(16, { + totalVram: s1GB * 6, + freeVram: s1GB * 3 + }); + expect(res.gpuLayers).to.eql(16); + expect(res.contextSize).to.toMatchInlineSnapshot("8037"); + } + try { + resolveGpuLayers(16, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + try { + resolveGpuLayers(16, { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + { + const res = resolveGpuLayers(16, { + totalVram: s1GB * 6, + + // play with this number to make the test pass, it should be low enough so that there won't be any VRAM left + // to create a context + freeVram: s1GB * 0.2, + + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(16); + expect(res.contextSize).to.eql(null); + } + + + { + const res = resolveGpuLayers(16, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + { + const res = resolveGpuLayers(16, { + totalVram: 0, + freeVram: 0, + llamaGpu: false, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + }); + + it("attempts to resolve 32 gpuLayers", () => { + { + const res = resolveGpuLayers(32, { + totalVram: s1GB * 6, + freeVram: s1GB * 6 + }); + expect(res.gpuLayers).to.eql(32); + expect(res.contextSize).to.toMatchInlineSnapshot("11859"); + } + try { + resolveGpuLayers(32, { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + { + const res = resolveGpuLayers(32, { + totalVram: s1GB * 6, + freeVram: s1GB * 0, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(32); + expect(res.contextSize).to.toMatchInlineSnapshot("null"); + } + + { + const res = resolveGpuLayers(32, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + { + const res = resolveGpuLayers(32, { + totalVram: 0, + freeVram: 0, + llamaGpu: false, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + }); + + it("attempts to resolve 33 gpuLayers", () => { + { + const res = resolveGpuLayers(33, { + totalVram: s1GB * 6, + freeVram: s1GB * 6 + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("11859"); + } + try { + resolveGpuLayers(33, { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + { + const res = resolveGpuLayers(33, { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("null"); + } + + { + const res = resolveGpuLayers(33, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + { + const res = resolveGpuLayers(33, { + totalVram: 0, + freeVram: 0, + llamaGpu: false, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + }); + + it('attempts to resolve "max"', () => { + try { + resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + + try { + resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + + try { + resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + + { + const res = resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 1.2, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("null"); + }{ + const res = resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("633"); + } + { + const res = resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.4 + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("2878"); + } + { + const res = resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.8 + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("5123"); + } + }); + + it('attempts to resolve "auto"', () => { + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 0.4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("0"); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 0.8 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("1"); + expect(res.contextSize).to.toMatchInlineSnapshot("7608"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 1.4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("5"); + expect(res.contextSize).to.toMatchInlineSnapshot("7964"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 2.4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("11"); + expect(res.contextSize).to.toMatchInlineSnapshot("9310"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.1 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("16"); + expect(res.contextSize).to.toMatchInlineSnapshot("8891"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.3 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("18"); + expect(res.contextSize).to.toMatchInlineSnapshot("8118"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.5 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("19"); + expect(res.contextSize).to.toMatchInlineSnapshot("8544"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.8 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("21"); + expect(res.contextSize).to.toMatchInlineSnapshot("8590"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("22"); + expect(res.contextSize).to.toMatchInlineSnapshot("8968"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.3 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("24"); + expect(res.contextSize).to.toMatchInlineSnapshot("8988"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.5 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("3439"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.8 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("5123"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 5.2 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("7368"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 5.8 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("10736"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 6 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("11859"); + } + }); + + it("attempts to resolve {min?: number, max?: number}", () => { + { + const res = resolveGpuLayers({max: 4}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + { + const res = resolveGpuLayers({min: 0, max: 4}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + } + try { + resolveGpuLayers({min: 2}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + try { + resolveGpuLayers({min: 2, max: 4}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + + { + const res = resolveGpuLayers({max: 16}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.eql(16); + expect(res.contextSize).to.toMatchInlineSnapshot("16575"); + } + try { + resolveGpuLayers({min: 16}, { + totalVram: s1GB * 6, + freeVram: s1GB * 2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + { + const res = resolveGpuLayers({min: 16}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.be.gte(16); + expect(res.gpuLayers).to.toMatchInlineSnapshot("22"); + expect(res.contextSize).to.toMatchInlineSnapshot("8968"); + } + { + const res = resolveGpuLayers({min: 16, max: 24}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.be.gte(16); + expect(res.gpuLayers).to.be.lte(24); + expect(res.gpuLayers).to.toMatchInlineSnapshot("22"); + expect(res.contextSize).to.toMatchInlineSnapshot("8968"); + } + { + const res = resolveGpuLayers({min: 16, max: 24}, { + totalVram: s1GB * 6, + freeVram: s1GB * 3 + }); + expect(res.gpuLayers).to.be.gte(16); + expect(res.gpuLayers).to.be.lte(24); + expect(res.gpuLayers).to.toMatchInlineSnapshot("16"); + expect(res.contextSize).to.toMatchInlineSnapshot("8037"); + } + }); + + it("attempts to resolve {fitContext?: {contextSize?: number}}", () => { + { + const contextSize = 4096; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("0"); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 4096; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("26"); + expect(res.contextSize).to.toMatchInlineSnapshot("5142"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 4096; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 2, + freeVram: s1GB * 1 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("4"); + expect(res.contextSize).to.toMatchInlineSnapshot("4385"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 8192; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("22"); + expect(res.contextSize).to.toMatchInlineSnapshot("8968"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 8192; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 1, + freeVram: s1GB * 1 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("2"); + expect(res.contextSize).to.toMatchInlineSnapshot("8497"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 8192; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 0, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("0"); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + try { + resolveGpuLayers({min: 1, fitContext: {contextSize: 8192}}, { + totalVram: s1GB * 0.2, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + } + { + const contextSize = 16384; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("32768"); + expect(res.contextSize).to.be.gte(contextSize); + } + }); + }); + }); +}); diff --git a/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts b/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts new file mode 100644 index 00000000..6c6b091f --- /dev/null +++ b/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts @@ -0,0 +1,627 @@ +import {describe, expect, it} from "vitest"; +import {getModelFile} from "../../utils/modelFiles.js"; +import {getTestLlama} from "../../utils/getTestLlama.js"; +import {LlamaModelOptions, resolveModelGpuLayersOption} from "../../../src/evaluator/LlamaModel.js"; +import {readGgufFileInfo} from "../../../src/gguf/readGgufFileInfo.js"; +import {GgufInsights} from "../../../src/gguf/GgufInsights.js"; +import {defaultLlamaVramPadding} from "../../../src/bindings/getLlama.js"; +import {BuildGpu} from "../../../src/bindings/types.js"; +import {resolveContextContextSizeOption} from "../../../src/evaluator/LlamaContext/LlamaContext.js"; + +describe("stableCode", () => { + describe("model options", () => { + describe("Resolve the correct number of GPU layers", async () => { + const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); + const llama = await getTestLlama(); + + const fileInfo = await readGgufFileInfo(modelPath); + const ggufInsights = await GgufInsights.from(fileInfo, llama); + + const s1GB = Math.pow(1024, 3); + + function resolveGpuLayers(gpuLayers: LlamaModelOptions["gpuLayers"], { + totalVram, freeVram, ignoreMemorySafetyChecks = false, llamaGpu = "metal" + }: { + totalVram: number, freeVram: number, ignoreMemorySafetyChecks?: boolean, llamaGpu?: BuildGpu + }) { + const resolvedGpuLayers = resolveModelGpuLayersOption(gpuLayers, { + ggufInsights, + ignoreMemorySafetyChecks, + getVramState: () => ({ + total: llamaGpu === false ? 0 : totalVram, + free: llamaGpu === false ? 0 : freeVram + }), + llamaVramPaddingSize: defaultLlamaVramPadding(llamaGpu === false ? 0 : totalVram), + llamaGpu, + llamaSupportsGpuOffloading: llamaGpu !== false + }); + + function resolveAutoContextSize() { + const modelVram = ggufInsights.estimateModelResourceRequirements({ + gpuLayers: resolvedGpuLayers + }).gpuVram; + + try { + return resolveContextContextSizeOption({ + contextSize: "auto", + batchSize: undefined, + sequences: 1, + modelFileInsights: ggufInsights, + modelGpuLayers: resolvedGpuLayers, + modelTrainContextSize: ggufInsights.trainContextSize ?? 4096, + getVramState: () => ({ + total: llamaGpu === false ? 0 : totalVram, + free: llamaGpu === false ? 0 : (freeVram - modelVram) + }), + llamaGpu, + ignoreMemorySafetyChecks: false, + isEmbeddingContext: false + }); + } catch (err) { + return null; + } + } + + return { + gpuLayers: resolvedGpuLayers, + contextSize: resolveAutoContextSize() + }; + } + + it("attempts to resolve 0 gpuLayers", () => { + { + const res = resolveGpuLayers(0, { + totalVram: s1GB * 6, + freeVram: s1GB * 1 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + { + const res = resolveGpuLayers(0, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + + { + const res = resolveGpuLayers(0, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + }); + + it("attempts to resolve 16 gpuLayers", () => { + { + const res = resolveGpuLayers(16, { + totalVram: s1GB * 6, + freeVram: s1GB * 3 + }); + expect(res.gpuLayers).to.eql(16); + expect(res.contextSize).to.toMatchInlineSnapshot("9530"); + } + try { + resolveGpuLayers(16, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + try { + resolveGpuLayers(16, { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + { + const res = resolveGpuLayers(16, { + totalVram: s1GB * 6, + + // play with this number to make the test pass, it should be low enough so that there won't be any VRAM left + // to create a context + freeVram: s1GB * 0.2, + + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(16); + expect(res.contextSize).to.eql(null); + } + + + { + const res = resolveGpuLayers(16, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + { + const res = resolveGpuLayers(16, { + totalVram: 0, + freeVram: 0, + llamaGpu: false, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + }); + + it("attempts to resolve 32 gpuLayers", () => { + { + const res = resolveGpuLayers(32, { + totalVram: s1GB * 6, + freeVram: s1GB * 6 + }); + expect(res.gpuLayers).to.eql(32); + expect(res.contextSize).to.toMatchInlineSnapshot("11565"); + } + try { + resolveGpuLayers(32, { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + { + const res = resolveGpuLayers(32, { + totalVram: s1GB * 6, + freeVram: s1GB * 0, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(32); + expect(res.contextSize).to.toMatchInlineSnapshot("null"); + } + + { + const res = resolveGpuLayers(32, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + { + const res = resolveGpuLayers(32, { + totalVram: 0, + freeVram: 0, + llamaGpu: false, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + }); + + it("attempts to resolve 33 gpuLayers", () => { + { + const res = resolveGpuLayers(33, { + totalVram: s1GB * 6, + freeVram: s1GB * 6 + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("11565"); + } + try { + resolveGpuLayers(33, { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + { + const res = resolveGpuLayers(33, { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("null"); + } + + { + const res = resolveGpuLayers(33, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + { + const res = resolveGpuLayers(33, { + totalVram: 0, + freeVram: 0, + llamaGpu: false, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + }); + + it('attempts to resolve "max"', () => { + try { + resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + + try { + resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 0.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + + try { + resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[AssertionError: expected \"Should have thrown an error\" not to be reached]"); + } + + { + const res = resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 1.2, + ignoreMemorySafetyChecks: true + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("null"); + }{ + const res = resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("5853"); + } + { + const res = resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.4 + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("6995"); + } + { + const res = resolveGpuLayers("max", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.8 + }); + expect(res.gpuLayers).to.eql(33); + expect(res.contextSize).to.toMatchInlineSnapshot("8138"); + } + }); + + it('attempts to resolve "auto"', () => { + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 0.4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("0"); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 0.8 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("4"); + expect(res.contextSize).to.toMatchInlineSnapshot("3742"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 1.4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("10"); + expect(res.contextSize).to.toMatchInlineSnapshot("4249"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 2.4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("1282"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.1 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("3282"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.3 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("3853"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.5 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("4424"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 3.8 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("5281"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("5853"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.3 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("6710"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.5 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("7281"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 4.8 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("8138"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 5.2 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("9280"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 5.8 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("10994"); + } + { + const res = resolveGpuLayers("auto", { + totalVram: s1GB * 6, + freeVram: s1GB * 6 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("11565"); + } + }); + + it("attempts to resolve {min?: number, max?: number}", () => { + { + const res = resolveGpuLayers({max: 4}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + { + const res = resolveGpuLayers({min: 0, max: 4}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + } + try { + resolveGpuLayers({min: 2}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + try { + resolveGpuLayers({min: 2, max: 4}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + + { + const res = resolveGpuLayers({max: 16}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.eql(16); + expect(res.contextSize).to.toMatchInlineSnapshot("14593"); + } + try { + resolveGpuLayers({min: 16}, { + totalVram: s1GB * 6, + freeVram: s1GB * 2 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[AssertionError: expected \"Should have thrown an error\" not to be reached]"); + } + { + const res = resolveGpuLayers({min: 16}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.be.gte(16); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("5853"); + } + { + const res = resolveGpuLayers({min: 16, max: 24}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.be.gte(16); + expect(res.gpuLayers).to.be.lte(24); + expect(res.gpuLayers).to.toMatchInlineSnapshot("24"); + expect(res.contextSize).to.toMatchInlineSnapshot("9020"); + } + { + const res = resolveGpuLayers({min: 16, max: 24}, { + totalVram: s1GB * 6, + freeVram: s1GB * 3 + }); + expect(res.gpuLayers).to.be.gte(16); + expect(res.gpuLayers).to.be.lte(24); + expect(res.gpuLayers).to.toMatchInlineSnapshot("18"); + expect(res.contextSize).to.toMatchInlineSnapshot("8208"); + } + }); + + it("attempts to resolve {fitContext?: {contextSize?: number}}", () => { + { + const contextSize = 4096; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: 0, + freeVram: 0, + llamaGpu: false + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("0"); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 4096; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); + expect(res.contextSize).to.toMatchInlineSnapshot("5853"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 4096; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 2, + freeVram: s1GB * 1 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("5"); + expect(res.contextSize).to.toMatchInlineSnapshot("4959"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 8192; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 6, + freeVram: s1GB * 4 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("25"); + expect(res.contextSize).to.toMatchInlineSnapshot("8537"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 8192; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 1, + freeVram: s1GB * 1 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("2"); + expect(res.contextSize).to.toMatchInlineSnapshot("9618"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + const contextSize = 8192; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 0, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.toMatchInlineSnapshot("0"); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + expect(res.contextSize).to.be.gte(contextSize); + } + { + try { + resolveGpuLayers({min: 1, fitContext: {contextSize: 8192}}, { + totalVram: s1GB * 0.2, + freeVram: s1GB * 0 + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(err).toMatchInlineSnapshot("[Error: Not enough VRAM to fit the model with the specified settings]"); + } + } + { + const contextSize = 16384; + const res = resolveGpuLayers({fitContext: {contextSize}}, { + totalVram: s1GB * 6, + freeVram: s1GB * 0 + }); + expect(res.gpuLayers).to.eql(0); + expect(res.contextSize).to.toMatchInlineSnapshot("16384"); + expect(res.contextSize).to.be.gte(contextSize); + } + }); + }); + }); +}); From 87ff5e8971aaff012f88d26a88da0554be4ed799 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 29 Mar 2024 01:36:29 +0200 Subject: [PATCH 24/52] feat: simplify `chat`, `complete` and `infill` commands, list GPU device names --- .vitepress/utils/getCommandHtmlDoc.ts | 8 +- llama/addon.cpp | 29 +++- llama/gpuInfo/cuda-gpu-info.cu | 20 +++ llama/gpuInfo/cuda-gpu-info.h | 2 + llama/gpuInfo/metal-gpu-info.h | 4 +- llama/gpuInfo/metal-gpu-info.mm | 14 +- llama/gpuInfo/vulkan-gpu-info.cpp | 20 ++- llama/gpuInfo/vulkan-gpu-info.h | 2 + src/bindings/AddonTypes.ts | 3 + src/bindings/Llama.ts | 8 + src/cli/commands/ChatCommand.ts | 212 +++++++++++++++++++------- src/cli/commands/CompleteCommand.ts | 179 ++++++++++++++++------ src/cli/commands/InfillCommand.ts | 179 ++++++++++++++++------ src/cli/utils/printInfoLine.ts | 66 ++++++++ src/utils/withOra.ts | 7 +- 15 files changed, 591 insertions(+), 162 deletions(-) create mode 100644 src/cli/utils/printInfoLine.ts diff --git a/.vitepress/utils/getCommandHtmlDoc.ts b/.vitepress/utils/getCommandHtmlDoc.ts index c2041d99..b9cae0ac 100644 --- a/.vitepress/utils/getCommandHtmlDoc.ts +++ b/.vitepress/utils/getCommandHtmlDoc.ts @@ -209,8 +209,12 @@ function renderOptionsGroupOptionsTable(options: {name: string, option: Options} let optionDescription: string[] = option.description != null ? [htmlEscape(option.description)] : []; - if (option.default != null) { - optionDescription.push(`(${htmlEscape("default: ")}${htmlEscape(option.default)})`); + const hasDefaultDescription = option.defaultDescription != null && option.defaultDescription.trim().length > 0; + if (option.default != null || hasDefaultDescription) { + if (hasDefaultDescription && option.defaultDescription != null) + optionDescription.push(`(${htmlEscape("default: ")}${htmlEscape(option.defaultDescription.trim())})`); + else + optionDescription.push(`(${htmlEscape("default: ")}${htmlEscape(option.default)})`); } if (option.type != null) { diff --git a/llama/addon.cpp b/llama/addon.cpp index 060635c4..273adfca 100644 --- a/llama/addon.cpp +++ b/llama/addon.cpp @@ -161,7 +161,7 @@ Napi::Value getGpuVramInfo(const Napi::CallbackInfo& info) { #ifdef GPU_INFO_USE_METAL uint64_t metalDeviceTotal = 0; uint64_t metalDeviceUsed = 0; - get_metal_gpu_info(&metalDeviceTotal, &metalDeviceUsed); + getMetalGpuInfo(&metalDeviceTotal, &metalDeviceUsed); total += metalDeviceTotal; used += metalDeviceUsed; @@ -174,6 +174,32 @@ Napi::Value getGpuVramInfo(const Napi::CallbackInfo& info) { return result; } +Napi::Value getGpuDeviceInfo(const Napi::CallbackInfo& info) { + std::vector deviceNames; + +#ifdef GPU_INFO_USE_CUBLAS + gpuInfoGetCudaDeviceNames(&deviceNames, logCudaError); +#endif + +#ifdef GPU_INFO_USE_VULKAN + gpuInfoGetVulkanDeviceNames(&deviceNames, logVulkanWarning); +#endif + +#ifdef GPU_INFO_USE_METAL + getMetalGpuDeviceNames(&deviceNames); +#endif + + Napi::Object result = Napi::Object::New(info.Env()); + + Napi::Array deviceNamesNapiArray = Napi::Array::New(info.Env(), deviceNames.size()); + for (size_t i = 0; i < deviceNames.size(); ++i) { + deviceNamesNapiArray[i] = Napi::String::New(info.Env(), deviceNames[i]); + } + result.Set("deviceNames", deviceNamesNapiArray); + + return result; +} + Napi::Value getGpuType(const Napi::CallbackInfo& info) { #ifdef GPU_INFO_USE_CUBLAS return Napi::String::New(info.Env(), "cuda"); @@ -1769,6 +1795,7 @@ Napi::Object registerCallback(Napi::Env env, Napi::Object exports) { Napi::PropertyDescriptor::Function("setLogger", setLogger), Napi::PropertyDescriptor::Function("setLoggerLogLevel", setLoggerLogLevel), Napi::PropertyDescriptor::Function("getGpuVramInfo", getGpuVramInfo), + Napi::PropertyDescriptor::Function("getGpuDeviceInfo", getGpuDeviceInfo), Napi::PropertyDescriptor::Function("getGpuType", getGpuType), Napi::PropertyDescriptor::Function("init", addonInit), Napi::PropertyDescriptor::Function("dispose", addonDispose), diff --git a/llama/gpuInfo/cuda-gpu-info.cu b/llama/gpuInfo/cuda-gpu-info.cu index 62a9bd89..d681b46f 100644 --- a/llama/gpuInfo/cuda-gpu-info.cu +++ b/llama/gpuInfo/cuda-gpu-info.cu @@ -1,4 +1,5 @@ #include +#include #if defined(GPU_INFO_USE_HIPBLAS) #include @@ -97,3 +98,22 @@ bool gpuInfoGetTotalCudaDevicesInfo(size_t * total, size_t * used, gpuInfoCudaEr return true; } + +void gpuInfoGetCudaDeviceNames(std::vector * deviceNames, gpuInfoCudaErrorLogCallback_t errorLogCallback) { + int deviceCount = gpuInfoGetCudaDeviceCount(errorLogCallback); + + if (deviceCount < 0) { + return; + } + + for (int i = 0; i < deviceCount; i++) { + cudaDeviceProp prop; + auto getDevicePropertiesResult = cudaGetDeviceProperties(&prop, i); + + if (getDevicePropertiesResult != cudaSuccess) { + errorLogCallback(cudaGetErrorString(getDevicePropertiesResult)); + } else { + (*deviceNames)->push_back(std::string(prop.name)); + } + } +} diff --git a/llama/gpuInfo/cuda-gpu-info.h b/llama/gpuInfo/cuda-gpu-info.h index dfd0bbdd..570bdc69 100644 --- a/llama/gpuInfo/cuda-gpu-info.h +++ b/llama/gpuInfo/cuda-gpu-info.h @@ -1,7 +1,9 @@ #pragma once #include +#include typedef void (*gpuInfoCudaErrorLogCallback_t)(const char* message); bool gpuInfoGetTotalCudaDevicesInfo(size_t * total, size_t * used, gpuInfoCudaErrorLogCallback_t errorLogCallback); +void gpuInfoGetCudaDeviceNames(std::vector * deviceNames, gpuInfoCudaErrorLogCallback_t errorLogCallback); diff --git a/llama/gpuInfo/metal-gpu-info.h b/llama/gpuInfo/metal-gpu-info.h index 29b9d022..07b0d56b 100644 --- a/llama/gpuInfo/metal-gpu-info.h +++ b/llama/gpuInfo/metal-gpu-info.h @@ -1,5 +1,7 @@ #pragma once #include +#include -void get_metal_gpu_info(uint64_t * total, uint64_t * used); +void getMetalGpuInfo(uint64_t * total, uint64_t * used); +void getMetalGpuDeviceNames(std::vector * deviceNames); \ No newline at end of file diff --git a/llama/gpuInfo/metal-gpu-info.mm b/llama/gpuInfo/metal-gpu-info.mm index fd72688a..61c7d510 100644 --- a/llama/gpuInfo/metal-gpu-info.mm +++ b/llama/gpuInfo/metal-gpu-info.mm @@ -1,7 +1,8 @@ #include +#include #import -void get_metal_gpu_info(uint64_t * total, uint64_t * used) { +void getMetalGpuInfo(uint64_t * total, uint64_t * used) { id device = MTLCreateSystemDefaultDevice(); if (device) { @@ -15,3 +16,14 @@ void get_metal_gpu_info(uint64_t * total, uint64_t * used) { [device release]; device = nil; } + +void getMetalGpuDeviceNames(std::vector * deviceNames) { + NSArray> *devices = MTLCopyAllDevices(); + + for (id device in devices) { + (*deviceNames).push_back(std::string(([NSString stringWithUTF8String:device.name.UTF8String]).UTF8String)); + } + + [devices release]; + devices = nil; +} diff --git a/llama/gpuInfo/vulkan-gpu-info.cpp b/llama/gpuInfo/vulkan-gpu-info.cpp index e95b0582..4bdbbe7d 100644 --- a/llama/gpuInfo/vulkan-gpu-info.cpp +++ b/llama/gpuInfo/vulkan-gpu-info.cpp @@ -1,10 +1,22 @@ #include +#include #include typedef void (*gpuInfoVulkanWarningLogCallback_t)(const char* message); bool gpuInfoGetTotalVulkanDevicesInfo(size_t* total, size_t* used, gpuInfoVulkanWarningLogCallback_t warningLogCallback) { + return enumerateVulkanDevices(total, used, false, nullptr, warningLogCallback); +} + +bool gpuInfoGetVulkanDeviceNames(std::vector * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback) { + size_t vulkanDeviceTotal = 0; + size_t vulkanDeviceUsed = 0; + + return enumerateVulkanDevices(&vulkanDeviceTotal, &vulkanDeviceUsed, true, deviceNames, warningLogCallback); +} + +static bool enumerateVulkanDevices(size_t* total, size_t* used, bool addDeviceNames, std::vector * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback) { vk::ApplicationInfo appInfo("node-llama-cpp GPU info", 1, "llama.cpp", 1, VK_API_VERSION_1_2); vk::InstanceCreateInfo createInfo(vk::InstanceCreateFlags(), &appInfo, {}, {}); vk::Instance instance = vk::createInstance(createInfo); @@ -41,8 +53,14 @@ bool gpuInfoGetTotalVulkanDevicesInfo(size_t* total, size_t* used, gpuInfoVulkan for (uint32_t i = 0; i < memProps.memoryHeapCount; ++i) { if (memProps.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { - totalMem += memProps.memoryHeaps[i].size; + const auto size = memProps.memoryHeaps[i].size; + totalMem += size; usedMem += memoryBudgetProperties.heapUsage[i]; + + if (size > 0 && addDeviceNames) { + (*deviceNames).push_back(std::string(deviceProps.deviceName.data())); + } + break; } } diff --git a/llama/gpuInfo/vulkan-gpu-info.h b/llama/gpuInfo/vulkan-gpu-info.h index 6a2fbe40..d2457f10 100644 --- a/llama/gpuInfo/vulkan-gpu-info.h +++ b/llama/gpuInfo/vulkan-gpu-info.h @@ -1,7 +1,9 @@ #pragma once #include +#include typedef void (*gpuInfoVulkanWarningLogCallback_t)(const char* message); bool gpuInfoGetTotalVulkanDevicesInfo(size_t* total, size_t* used, gpuInfoVulkanWarningLogCallback_t warningLogCallback); +bool gpuInfoGetVulkanDeviceNames(std::vector * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback); \ No newline at end of file diff --git a/src/bindings/AddonTypes.ts b/src/bindings/AddonTypes.ts index 7cff953c..ed13b1fa 100644 --- a/src/bindings/AddonTypes.ts +++ b/src/bindings/AddonTypes.ts @@ -53,6 +53,9 @@ export type BindingModule = { total: number, used: number }, + getGpuDeviceInfo(): { + deviceNames: string[] + }, getGpuType(): "cuda" | "vulkan" | "metal" | undefined, init(): Promise, dispose(): Promise diff --git a/src/bindings/Llama.ts b/src/bindings/Llama.ts index 20cd5ec2..a6e3d95a 100644 --- a/src/bindings/Llama.ts +++ b/src/bindings/Llama.ts @@ -193,6 +193,14 @@ export class Llama { }; } + public getGpuDeviceNames() { + this._ensureNotDisposed(); + + const {deviceNames} = this._bindings.getGpuDeviceInfo(); + + return deviceNames; + } + public async loadModel(options: LlamaModelOptions) { this._ensureNotDisposed(); diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts index e46fdf70..c8a6e196 100644 --- a/src/cli/commands/ChatCommand.ts +++ b/src/cli/commands/ChatCommand.ts @@ -8,16 +8,19 @@ import bytes from "bytes"; import {chatCommandHistoryFilePath, defaultChatSystemPrompt} from "../../config.js"; import {getIsInDocumentationMode} from "../../state.js"; import {ReplHistory} from "../../utils/ReplHistory.js"; -import withStatusLogs from "../../utils/withStatusLogs.js"; import {defineChatSessionFunction} from "../../evaluator/LlamaChatSession/utils/defineChatSessionFunction.js"; import {getLlama} from "../../bindings/getLlama.js"; import {LlamaGrammar} from "../../evaluator/LlamaGrammar.js"; import {LlamaChatSession} from "../../evaluator/LlamaChatSession/LlamaChatSession.js"; import {LlamaJsonSchemaGrammar} from "../../evaluator/LlamaJsonSchemaGrammar.js"; -import {LlamaLogLevel} from "../../bindings/types.js"; +import {LlamaLogLevel, LlamaLogLevelGreaterThan} from "../../bindings/types.js"; import { ChatWrapperTypeName, chatWrapperTypeNames, resolveChatWrapperBasedOnWrapperTypeName } from "../../bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.js"; +import withOra from "../../utils/withOra.js"; +import {TokenMeter} from "../../evaluator/TokenMeter.js"; +import {printInfoLine} from "../utils/printInfoLine.js"; +import {getPrettyBuildGpuName} from "../../bindings/consts.js"; type ChatCommand = { model: string, @@ -27,7 +30,7 @@ type ChatCommand = { prompt?: string, promptFile?: string, wrapper: ChatWrapperTypeName, - contextSize: number, + contextSize?: number, batchSize?: number, grammar: "text" | Parameters[1], jsonSchemaGrammarFile?: string, @@ -45,19 +48,20 @@ type ChatCommand = { maxTokens: number, noHistory: boolean, environmentFunctions: boolean, - noInfoLog: boolean, + debug: boolean, + meter: boolean, printTimings: boolean }; export const ChatCommand: CommandModule = { - command: "chat", + command: "chat [modelPath]", describe: "Chat with a Llama model", builder(yargs) { const isInDocumentationMode = getIsInDocumentationMode(); return yargs .option("model", { - alias: "m", + alias: ["m", "modelPath"], type: "string", demandOption: true, description: "Llama model file to use for the chat", @@ -106,8 +110,8 @@ export const ChatCommand: CommandModule = { .option("contextSize", { alias: "c", type: "number", - default: 1024 * 4, description: "Context size to use for the model context", + defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) .option("batchSize", { @@ -168,6 +172,7 @@ export const ChatCommand: CommandModule = { alias: "gl", type: "number", description: "number of layers to store in VRAM", + defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) .option("repeatPenalty", { @@ -224,11 +229,17 @@ export const ChatCommand: CommandModule = { description: "Provide access to environment functions like `getDate` and `getTime`", group: "Optional:" }) - .option("noInfoLog", { - alias: "nl", + .option("debug", { + alias: "d", type: "boolean", default: false, - description: "Disable llama.cpp info logs", + description: "Print llama.cpp info and debug logs", + group: "Optional:" + }) + .option("meter", { + type: "boolean", + default: false, + description: "Log how many tokens were used as input and output for each response", group: "Optional:" }) .option("printTimings", { @@ -245,14 +256,14 @@ export const ChatCommand: CommandModule = { grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, - environmentFunctions, noInfoLog, printTimings + environmentFunctions, debug, meter, printTimings }) { try { await RunChat({ model, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, contextSize, batchSize, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, - noHistory, environmentFunctions, noInfoLog, printTimings + noHistory, environmentFunctions, debug, meter, printTimings }); } catch (err) { await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing @@ -267,15 +278,16 @@ async function RunChat({ model: modelArg, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, contextSize, batchSize, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath, threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, - maxTokens, noHistory, environmentFunctions, noInfoLog, printTimings + maxTokens, noHistory, environmentFunctions, debug, meter, printTimings }: ChatCommand) { - if (noInfoLog) - console.info(`${chalk.yellow("Log level:")} warn`); + if (debug) + console.info(`${chalk.yellow("Log level:")} debug`); + const llamaLogLevel = debug + ? LlamaLogLevel.debug + : LlamaLogLevel.warn; const llama = await getLlama("lastBuild", { - logLevel: noInfoLog - ? LlamaLogLevel.warn - : LlamaLogLevel.debug + logLevel: llamaLogLevel }); const logBatchSize = batchSize != null; @@ -296,18 +308,17 @@ async function RunChat({ prompt = await fs.readFile(path.resolve(process.cwd(), promptFile), "utf8"); } - if (batchSize == null) - batchSize = contextSize; - else if (batchSize > contextSize) { + if (batchSize != null && contextSize != null && batchSize > contextSize) { console.warn(chalk.yellow("Batch size is greater than the context size. Batch size will be set to the context size.")); batchSize = contextSize; } let initialPrompt = prompt ?? null; - const model = await withStatusLogs({ + const model = await withOra({ loading: chalk.blue("Loading model"), success: chalk.blue("Model loaded"), - fail: chalk.blue("Failed to load model") + fail: chalk.blue("Failed to load model"), + useStatusLogs: debug }, async () => { try { return await llama.loadModel({ @@ -321,15 +332,16 @@ async function RunChat({ } } }); - const context = await withStatusLogs({ + const context = await withOra({ loading: chalk.blue("Creating context"), success: chalk.blue("Context created"), - fail: chalk.blue("Failed to create context") + fail: chalk.blue("Failed to create context"), + useStatusLogs: debug }, async () => { try { return await model.createContext({ - contextSize, - batchSize, + contextSize: contextSize != null ? contextSize : undefined, + batchSize: batchSize != null ? batchSize : undefined, threads }); } finally { @@ -354,53 +366,120 @@ async function RunChat({ const chatWrapper = resolveChatWrapperBasedOnWrapperTypeName(wrapper, { bosString: bos, filename: model.filename, - typeDescription: model.typeDescription + fileInfo: model.fileInfo }); + const contextSequence = context.getSequence(); const session = new LlamaChatSession({ - contextSequence: context.getSequence(), + contextSequence, systemPrompt, chatWrapper: chatWrapper }); + let lastTokenMeterState = contextSequence.tokenMeter.getState(); await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing if (grammarArg != "text" && jsonSchemaGrammarFilePath != null) console.warn(chalk.yellow("Both `grammar` and `jsonSchemaGrammarFile` were specified. `jsonSchemaGrammarFile` will be used.")); - console.info(`${chalk.yellow("Context size:")} ${context.contextSize}`); - - if (logBatchSize) - console.info(`${chalk.yellow("Batch size:")} ${context.batchSize}`); - - console.info(`${chalk.yellow("Train context size:")} ${model.trainContextSize}`); - console.info(`${chalk.yellow("Model type:")} ${model.typeDescription}`); - console.info(`${chalk.yellow("Model size:")} ${bytes(model.size)}`); - console.info(`${chalk.yellow("BOS:")} ${bos}`); - console.info(`${chalk.yellow("EOS:")} ${eos}`); - console.info(`${chalk.yellow("Chat wrapper:")} ${chatWrapper.wrapperName}`); - console.info(`${chalk.yellow("Repeat penalty:")} ${repeatPenalty} (apply to last ${lastTokensRepeatPenalty} tokens)`); - - if (repeatFrequencyPenalty != null) - console.info(`${chalk.yellow("Repeat frequency penalty:")} ${repeatFrequencyPenalty}`); - - if (repeatPresencePenalty != null) - console.info(`${chalk.yellow("Repeat presence penalty:")} ${repeatPresencePenalty}`); - - if (!penalizeRepeatingNewLine) - console.info(`${chalk.yellow("Penalize repeating new line:")} disabled`); - - if (jsonSchemaGrammarFilePath != null) - console.info(`${chalk.yellow("JSON schema grammar file:")} ${ - path.relative(process.cwd(), path.resolve(process.cwd(), jsonSchemaGrammarFilePath)) - }`); - else if (grammarArg !== "text") - console.info(`${chalk.yellow("Grammar:")} ${grammarArg}`); - if (environmentFunctions && grammar != null) { console.warn(chalk.yellow("Environment functions are disabled since a grammar is already specified")); environmentFunctions = false; } + const padTitle = "Context".length + 1; + if (llama.gpu !== false) { + printInfoLine({ + title: "GPU", + padTitle: padTitle, + info: [{ + title: "Type", + value: getPrettyBuildGpuName(llama.gpu) + }, { + title: "VRAM", + value: bytes(llama.getVramState().total) + }, { + title: "Name", + value: llama.getGpuDeviceNames().join(", ") + }, { + title: "GPU layers", + value: `${model.gpuLayers}/${model.fileInsights.totalLayers} offloaded ${ + chalk.dim(`(${Math.floor((model.gpuLayers / model.fileInsights.totalLayers) * 100)}%)`) + }` + }] + }); + } + printInfoLine({ + title: "Model", + padTitle: padTitle, + info: [{ + title: "Type", + value: model.typeDescription + }, { + title: "Size", + value: bytes(model.size) + }, { + title: "BOS", + value: String(bos) + }, { + title: "EOS", + value: String(eos) + }, { + title: "Train context size", + value: String(model.trainContextSize) + }] + }); + printInfoLine({ + title: "Context", + padTitle: padTitle, + info: [{ + title: "Size", + value: String(context.contextSize) + }, { + show: logBatchSize, + title: "Batch size", + value: bytes(context.batchSize) + }, { + show: meter, + title: "Token meter", + value: "enabled" + }] + }); + printInfoLine({ + title: "Chat", + padTitle: padTitle, + info: [{ + title: "Wrapper", + value: chatWrapper.wrapperName + }, { + title: "Repeat penalty", + value: `${repeatPenalty} (apply to last ${lastTokensRepeatPenalty} tokens)` + }, { + show: repeatFrequencyPenalty != null, + title: "Repeat frequency penalty", + value: String(repeatFrequencyPenalty) + }, { + show: repeatPresencePenalty != null, + title: "Repeat presence penalty", + value: String(repeatPresencePenalty) + }, { + show: !penalizeRepeatingNewLine, + title: "Penalize repeating new line", + value: "disabled" + }, { + show: jsonSchemaGrammarFilePath != null, + title: "JSON schema grammar file", + value: () => path.relative(process.cwd(), path.resolve(process.cwd(), jsonSchemaGrammarFilePath ?? "")) + }, { + show: jsonSchemaGrammarFilePath == null && grammarArg !== "text", + title: "Grammar", + value: grammarArg + }, { + show: environmentFunctions, + title: "Environment functions", + value: "enabled" + }] + }); + // this is for ora to not interfere with readline await new Promise(resolve => setTimeout(resolve, 1)); @@ -467,8 +546,23 @@ async function RunChat({ process.stdout.write(endColor); console.log(); - if (printTimings) + if (printTimings) { + if (LlamaLogLevelGreaterThan(llama.logLevel, LlamaLogLevel.info)) + llama.logLevel = LlamaLogLevel.info; + await context.printTimings(); + await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing + + llama.logLevel = llamaLogLevel; + } + + if (meter) { + const newTokenMeterState = contextSequence.tokenMeter.getState(); + const tokenMeterDiff = TokenMeter.diff(newTokenMeterState, lastTokenMeterState); + lastTokenMeterState = newTokenMeterState; + + console.info(`${chalk.dim("Input tokens:")} ${String(tokenMeterDiff.usedInputTokens).padEnd(5, " ")} ${chalk.dim("Output tokens:")} ${tokenMeterDiff.usedOutputTokens}`); + } } } diff --git a/src/cli/commands/CompleteCommand.ts b/src/cli/commands/CompleteCommand.ts index db184cc6..cddd50de 100644 --- a/src/cli/commands/CompleteCommand.ts +++ b/src/cli/commands/CompleteCommand.ts @@ -4,17 +4,21 @@ import path from "path"; import {CommandModule} from "yargs"; import chalk from "chalk"; import fs from "fs-extra"; -import withStatusLogs from "../../utils/withStatusLogs.js"; +import bytes from "bytes"; import {getLlama} from "../../bindings/getLlama.js"; -import {LlamaLogLevel} from "../../bindings/types.js"; +import {LlamaLogLevel, LlamaLogLevelGreaterThan} from "../../bindings/types.js"; import {LlamaCompletion} from "../../evaluator/LlamaCompletion.js"; +import withOra from "../../utils/withOra.js"; +import {TokenMeter} from "../../evaluator/TokenMeter.js"; +import {printInfoLine} from "../utils/printInfoLine.js"; +import {getPrettyBuildGpuName} from "../../bindings/consts.js"; type CompleteCommand = { model: string, systemInfo: boolean, text?: string, textFile?: string, - contextSize: number, + contextSize?: number, batchSize?: number, threads: number, temperature: number, @@ -28,20 +32,21 @@ type CompleteCommand = { repeatFrequencyPenalty?: number, repeatPresencePenalty?: number, maxTokens: number, - noInfoLog: boolean, + debug: boolean, + meter: boolean, printTimings: boolean }; export const CompleteCommand: CommandModule = { - command: "complete", + command: "complete [modelPath]", describe: "Generate a completion for a given text", builder(yargs) { return yargs .option("model", { - alias: "m", + alias: ["m", "modelPath"], type: "string", demandOption: true, - description: "Llama model file to use for the chat", + description: "Llama model file to use for the completion", group: "Required:" }) .option("systemInfo", { @@ -64,8 +69,8 @@ export const CompleteCommand: CommandModule = { .option("contextSize", { alias: "c", type: "number", - default: 1024 * 4, description: "Context size to use for the model context", + defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) .option("batchSize", { @@ -112,6 +117,7 @@ export const CompleteCommand: CommandModule = { alias: "gl", type: "number", description: "number of layers to store in VRAM", + defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) .option("repeatPenalty", { @@ -154,11 +160,17 @@ export const CompleteCommand: CommandModule = { description: "Maximum number of tokens to generate in responses. Set to `0` to disable. Set to `-1` to set to the context size", group: "Optional:" }) - .option("noInfoLog", { - alias: "nl", + .option("debug", { + alias: "d", type: "boolean", default: false, - description: "Disable llama.cpp info logs", + description: "Print llama.cpp info and debug logs", + group: "Optional:" + }) + .option("meter", { + type: "boolean", + default: false, + description: "Log how many tokens were used as input and output for each response", group: "Optional:" }) .option("printTimings", { @@ -174,14 +186,14 @@ export const CompleteCommand: CommandModule = { threads, temperature, minP, topK, topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, - noInfoLog, printTimings + debug, meter, printTimings }) { try { await RunCompletion({ model, systemInfo, text, textFile, contextSize, batchSize, threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, - noInfoLog, printTimings + debug, meter, printTimings }); } catch (err) { await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing @@ -196,15 +208,16 @@ async function RunCompletion({ model: modelArg, systemInfo, text, textFile, contextSize, batchSize, threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, - maxTokens, noInfoLog, printTimings + maxTokens, debug, meter, printTimings }: CompleteCommand) { - if (noInfoLog) - console.info(`${chalk.yellow("Log level:")} warn`); + if (debug) + console.info(`${chalk.yellow("Log level:")} debug`); + const llamaLogLevel = debug + ? LlamaLogLevel.debug + : LlamaLogLevel.warn; const llama = await getLlama("lastBuild", { - logLevel: noInfoLog - ? LlamaLogLevel.warn - : LlamaLogLevel.debug + logLevel: llamaLogLevel }); const logBatchSize = batchSize != null; @@ -218,18 +231,17 @@ async function RunCompletion({ text = await fs.readFile(path.resolve(process.cwd(), textFile), "utf8"); } - if (batchSize == null) - batchSize = contextSize; - else if (batchSize > contextSize) { + if (batchSize != null && contextSize != null && batchSize > contextSize) { console.warn(chalk.yellow("Batch size is greater than the context size. Batch size will be set to the context size.")); batchSize = contextSize; } let initialText = text ?? null; - const model = await withStatusLogs({ + const model = await withOra({ loading: chalk.blue("Loading model"), success: chalk.blue("Model loaded"), - fail: chalk.blue("Failed to load model") + fail: chalk.blue("Failed to load model"), + useStatusLogs: debug }, async () => { try { return await llama.loadModel({ @@ -243,15 +255,16 @@ async function RunCompletion({ } } }); - const context = await withStatusLogs({ + const context = await withOra({ loading: chalk.blue("Creating context"), success: chalk.blue("Context created"), - fail: chalk.blue("Failed to create context") + fail: chalk.blue("Failed to create context"), + useStatusLogs: debug }, async () => { try { return await model.createContext({ - contextSize, - batchSize, + contextSize: contextSize != null ? contextSize : undefined, + batchSize: batchSize != null ? batchSize : undefined, threads }); } finally { @@ -262,29 +275,86 @@ async function RunCompletion({ } }); + const contextSequence = context.getSequence(); const completion = new LlamaCompletion({ - contextSequence: context.getSequence() + contextSequence }); + let lastTokenMeterState = contextSequence.tokenMeter.getState(); await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing - console.info(`${chalk.yellow("Context size:")} ${context.contextSize}`); - - if (logBatchSize) - console.info(`${chalk.yellow("Batch size:")} ${context.batchSize}`); - - console.info(`${chalk.yellow("Train context size:")} ${model.trainContextSize}`); - console.info(`${chalk.yellow("Model type:")} ${model.typeDescription}`); - console.info(`${chalk.yellow("Repeat penalty:")} ${repeatPenalty} (apply to last ${lastTokensRepeatPenalty} tokens)`); - - if (repeatFrequencyPenalty != null) - console.info(`${chalk.yellow("Repeat frequency penalty:")} ${repeatFrequencyPenalty}`); - - if (repeatPresencePenalty != null) - console.info(`${chalk.yellow("Repeat presence penalty:")} ${repeatPresencePenalty}`); - - if (!penalizeRepeatingNewLine) - console.info(`${chalk.yellow("Penalize repeating new line:")} disabled`); + const padTitle = "Complete".length + 1; + if (llama.gpu !== false) { + printInfoLine({ + title: "GPU", + padTitle: padTitle, + info: [{ + title: "Type", + value: getPrettyBuildGpuName(llama.gpu) + }, { + title: "VRAM", + value: bytes(llama.getVramState().total) + }, { + title: "Name", + value: llama.getGpuDeviceNames().join(", ") + }, { + title: "GPU layers", + value: `${model.gpuLayers}/${model.fileInsights.totalLayers} offloaded ${ + chalk.dim(`(${Math.floor((model.gpuLayers / model.fileInsights.totalLayers) * 100)}%)`) + }` + }] + }); + } + printInfoLine({ + title: "Model", + padTitle: padTitle, + info: [{ + title: "Type", + value: model.typeDescription + }, { + title: "Size", + value: bytes(model.size) + }, { + title: "Train context size", + value: String(model.trainContextSize) + }] + }); + printInfoLine({ + title: "Context", + padTitle: padTitle, + info: [{ + title: "Size", + value: String(context.contextSize) + }, { + show: logBatchSize, + title: "Batch size", + value: bytes(context.batchSize) + }, { + show: meter, + title: "Token meter", + value: "enabled" + }] + }); + printInfoLine({ + title: "Complete", + padTitle: padTitle, + info: [{ + title: "Repeat penalty", + value: `${repeatPenalty} (apply to last ${lastTokensRepeatPenalty} tokens)` + }, { + show: repeatFrequencyPenalty != null, + title: "Repeat frequency penalty", + value: String(repeatFrequencyPenalty) + }, { + show: repeatPresencePenalty != null, + title: "Repeat presence penalty", + value: String(repeatPresencePenalty) + }, { + show: !penalizeRepeatingNewLine, + title: "Penalize repeating new line", + value: "disabled" + }] + }); // this is for ora to not interfere with readline await new Promise(resolve => setTimeout(resolve, 1)); @@ -348,7 +418,22 @@ async function RunCompletion({ process.stdout.write(endColor); console.log(); - if (printTimings) + if (printTimings) { + if (LlamaLogLevelGreaterThan(llama.logLevel, LlamaLogLevel.info)) + llama.logLevel = LlamaLogLevel.info; + await context.printTimings(); + await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing + + llama.logLevel = llamaLogLevel; + } + + if (meter) { + const newTokenMeterState = contextSequence.tokenMeter.getState(); + const tokenMeterDiff = TokenMeter.diff(newTokenMeterState, lastTokenMeterState); + lastTokenMeterState = newTokenMeterState; + + console.info(`${chalk.dim("Input tokens:")} ${String(tokenMeterDiff.usedInputTokens).padEnd(5, " ")} ${chalk.dim("Output tokens:")} ${tokenMeterDiff.usedOutputTokens}`); + } } } diff --git a/src/cli/commands/InfillCommand.ts b/src/cli/commands/InfillCommand.ts index 00df32a0..acd2203f 100644 --- a/src/cli/commands/InfillCommand.ts +++ b/src/cli/commands/InfillCommand.ts @@ -4,10 +4,14 @@ import path from "path"; import {CommandModule} from "yargs"; import chalk from "chalk"; import fs from "fs-extra"; -import withStatusLogs from "../../utils/withStatusLogs.js"; +import bytes from "bytes"; import {getLlama} from "../../bindings/getLlama.js"; -import {LlamaLogLevel} from "../../bindings/types.js"; +import {LlamaLogLevel, LlamaLogLevelGreaterThan} from "../../bindings/types.js"; import {LlamaCompletion} from "../../evaluator/LlamaCompletion.js"; +import withOra from "../../utils/withOra.js"; +import {TokenMeter} from "../../evaluator/TokenMeter.js"; +import {printInfoLine} from "../utils/printInfoLine.js"; +import {getPrettyBuildGpuName} from "../../bindings/consts.js"; type InfillCommand = { model: string, @@ -16,7 +20,7 @@ type InfillCommand = { prefixFile?: string, suffix?: string, suffixFile?: string, - contextSize: number, + contextSize?: number, batchSize?: number, threads: number, temperature: number, @@ -30,20 +34,21 @@ type InfillCommand = { repeatFrequencyPenalty?: number, repeatPresencePenalty?: number, maxTokens: number, - noInfoLog: boolean, + debug: boolean, + meter: boolean, printTimings: boolean }; export const InfillCommand: CommandModule = { - command: "infill", + command: "infill [modelPath]", describe: "Generate an infill completion for a given suffix and prefix texts", builder(yargs) { return yargs .option("model", { - alias: "m", + alias: ["m", "modelPath"], type: "string", demandOption: true, - description: "Llama model file to use for the chat", + description: "Llama model file to use for the infill", group: "Required:" }) .option("systemInfo", { @@ -76,8 +81,8 @@ export const InfillCommand: CommandModule = { .option("contextSize", { alias: "c", type: "number", - default: 1024 * 4, description: "Context size to use for the model context", + defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) .option("batchSize", { @@ -124,6 +129,7 @@ export const InfillCommand: CommandModule = { alias: "gl", type: "number", description: "number of layers to store in VRAM", + defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) .option("repeatPenalty", { @@ -166,11 +172,17 @@ export const InfillCommand: CommandModule = { description: "Maximum number of tokens to generate in responses. Set to `0` to disable. Set to `-1` to set to the context size", group: "Optional:" }) - .option("noInfoLog", { - alias: "nl", + .option("debug", { + alias: "d", type: "boolean", default: false, - description: "Disable llama.cpp info logs", + description: "Print llama.cpp info and debug logs", + group: "Optional:" + }) + .option("meter", { + type: "boolean", + default: false, + description: "Log how many tokens were used as input and output for each response", group: "Optional:" }) .option("printTimings", { @@ -186,14 +198,14 @@ export const InfillCommand: CommandModule = { threads, temperature, minP, topK, topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, - noInfoLog, printTimings + debug, meter, printTimings }) { try { await RunInfill({ model, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, - noInfoLog, printTimings + debug, meter, printTimings }); } catch (err) { await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing @@ -208,15 +220,16 @@ async function RunInfill({ model: modelArg, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, - maxTokens, noInfoLog, printTimings + maxTokens, debug, meter, printTimings }: InfillCommand) { - if (noInfoLog) - console.info(`${chalk.yellow("Log level:")} warn`); + if (debug) + console.info(`${chalk.yellow("Log level:")} debug`); + const llamaLogLevel = debug + ? LlamaLogLevel.debug + : LlamaLogLevel.warn; const llama = await getLlama("lastBuild", { - logLevel: noInfoLog - ? LlamaLogLevel.warn - : LlamaLogLevel.debug + logLevel: llamaLogLevel }); const logBatchSize = batchSize != null; @@ -242,9 +255,7 @@ async function RunInfill({ suffix = undefined; } - if (batchSize == null) - batchSize = contextSize; - else if (batchSize > contextSize) { + if (batchSize != null && contextSize != null && batchSize > contextSize) { console.warn(chalk.yellow("Batch size is greater than the context size. Batch size will be set to the context size.")); batchSize = contextSize; } @@ -252,10 +263,11 @@ async function RunInfill({ let initialPrefix = prefix ?? null; let initialSuffix = suffix ?? null; - const model = await withStatusLogs({ + const model = await withOra({ loading: chalk.blue("Loading model"), success: chalk.blue("Model loaded"), - fail: chalk.blue("Failed to load model") + fail: chalk.blue("Failed to load model"), + useStatusLogs: debug }, async () => { try { return await llama.loadModel({ @@ -269,15 +281,16 @@ async function RunInfill({ } } }); - const context = await withStatusLogs({ + const context = await withOra({ loading: chalk.blue("Creating context"), success: chalk.blue("Context created"), - fail: chalk.blue("Failed to create context") + fail: chalk.blue("Failed to create context"), + useStatusLogs: debug }, async () => { try { return await model.createContext({ - contextSize, - batchSize, + contextSize: contextSize != null ? contextSize : undefined, + batchSize: batchSize != null ? batchSize : undefined, threads }); } finally { @@ -288,29 +301,86 @@ async function RunInfill({ } }); + const contextSequence = context.getSequence(); const completion = new LlamaCompletion({ - contextSequence: context.getSequence() + contextSequence }); + let lastTokenMeterState = contextSequence.tokenMeter.getState(); await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing - console.info(`${chalk.yellow("Context size:")} ${context.contextSize}`); - - if (logBatchSize) - console.info(`${chalk.yellow("Batch size:")} ${context.batchSize}`); - - console.info(`${chalk.yellow("Train context size:")} ${model.trainContextSize}`); - console.info(`${chalk.yellow("Model type:")} ${model.typeDescription}`); - console.info(`${chalk.yellow("Repeat penalty:")} ${repeatPenalty} (apply to last ${lastTokensRepeatPenalty} tokens)`); - - if (repeatFrequencyPenalty != null) - console.info(`${chalk.yellow("Repeat frequency penalty:")} ${repeatFrequencyPenalty}`); - - if (repeatPresencePenalty != null) - console.info(`${chalk.yellow("Repeat presence penalty:")} ${repeatPresencePenalty}`); - - if (!penalizeRepeatingNewLine) - console.info(`${chalk.yellow("Penalize repeating new line:")} disabled`); + const padTitle = "Context".length + 1; + if (llama.gpu !== false) { + printInfoLine({ + title: "GPU", + padTitle: padTitle, + info: [{ + title: "Type", + value: getPrettyBuildGpuName(llama.gpu) + }, { + title: "VRAM", + value: bytes(llama.getVramState().total) + }, { + title: "Name", + value: llama.getGpuDeviceNames().join(", ") + }, { + title: "GPU layers", + value: `${model.gpuLayers}/${model.fileInsights.totalLayers} offloaded ${ + chalk.dim(`(${Math.floor((model.gpuLayers / model.fileInsights.totalLayers) * 100)}%)`) + }` + }] + }); + } + printInfoLine({ + title: "Model", + padTitle: padTitle, + info: [{ + title: "Type", + value: model.typeDescription + }, { + title: "Size", + value: bytes(model.size) + }, { + title: "Train context size", + value: String(model.trainContextSize) + }] + }); + printInfoLine({ + title: "Context", + padTitle: padTitle, + info: [{ + title: "Size", + value: String(context.contextSize) + }, { + show: logBatchSize, + title: "Batch size", + value: bytes(context.batchSize) + }, { + show: meter, + title: "Token meter", + value: "enabled" + }] + }); + printInfoLine({ + title: "Infill", + padTitle: padTitle, + info: [{ + title: "Repeat penalty", + value: `${repeatPenalty} (apply to last ${lastTokensRepeatPenalty} tokens)` + }, { + show: repeatFrequencyPenalty != null, + title: "Repeat frequency penalty", + value: String(repeatFrequencyPenalty) + }, { + show: repeatPresencePenalty != null, + title: "Repeat presence penalty", + value: String(repeatPresencePenalty) + }, { + show: !penalizeRepeatingNewLine, + title: "Penalize repeating new line", + value: "disabled" + }] + }); // this is for ora to not interfere with readline await new Promise(resolve => setTimeout(resolve, 1)); @@ -395,7 +465,22 @@ async function RunInfill({ process.stdout.write(endColor); console.log(); - if (printTimings) + if (printTimings) { + if (LlamaLogLevelGreaterThan(llama.logLevel, LlamaLogLevel.info)) + llama.logLevel = LlamaLogLevel.info; + await context.printTimings(); + await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing + + llama.logLevel = llamaLogLevel; + } + + if (meter) { + const newTokenMeterState = contextSequence.tokenMeter.getState(); + const tokenMeterDiff = TokenMeter.diff(newTokenMeterState, lastTokenMeterState); + lastTokenMeterState = newTokenMeterState; + + console.info(`${chalk.dim("Input tokens:")} ${String(tokenMeterDiff.usedInputTokens).padEnd(5, " ")} ${chalk.dim("Output tokens:")} ${tokenMeterDiff.usedOutputTokens}`); + } } } diff --git a/src/cli/utils/printInfoLine.ts b/src/cli/utils/printInfoLine.ts new file mode 100644 index 00000000..00c53935 --- /dev/null +++ b/src/cli/utils/printInfoLine.ts @@ -0,0 +1,66 @@ +import chalk from "chalk"; +import stripAnsi from "strip-ansi"; + +export function printInfoLine({ + title, padTitle = 0, separateLines = false, info +}: { + title?: string, + padTitle?: number, + separateLines?: boolean, + info: Array<{ + title: string, + value: string | (() => string), + show?: boolean + }> +}) { + const res: string[] = []; + const items: string[] = []; + if (separateLines) { + if (title != null && title.length > 0) + res.push(chalk.yellowBright(`${title.trim()}`)); + + for (const {title, value, show} of info) { + if (show === false) + continue; + + items.push(`${chalk.yellow(title + ":")} ${value instanceof Function ? value() : value}`); + } + + const itemPrefix = `${chalk.dim("|")} `; + res.push(itemPrefix + items.join("\n" + itemPrefix)); + console.info(res.join("\n") + "\n"); + } else { + if (title != null && title.length > 0) + res.push(chalk.yellowBright(`${title.padEnd(padTitle, " ")}`)); + + for (const {title, value, show} of info) { + if (show === false) + continue; + + items.push(chalk.bgGray(` ${chalk.yellow(title + ":")} ${value instanceof Function ? value() : value} `)); + } + + const startPad = stripAnsi(res.join(" ")).length + (res.length > 0 ? " ".length : 0); + res.push(splitItemsIntoLines(items, process.stdout.columns - 1 - startPad).join("\n" + " ".repeat(startPad))); + console.info(res.join(" ")); + } +} + +function splitItemsIntoLines(items: string[], maxLineLength: number) { + const lines: string[] = []; + let currentLine: string[] = []; + + for (const item of items) { + if (stripAnsi([...currentLine, item].join(" ")).length > maxLineLength) { + lines.push(currentLine.join(" ")); + currentLine = []; + } + + currentLine.push(item); + } + + if (currentLine.length > 0) + lines.push(currentLine.join(" ")); + + return lines; +} diff --git a/src/utils/withOra.ts b/src/utils/withOra.ts index 38ce4443..b7093abc 100644 --- a/src/utils/withOra.ts +++ b/src/utils/withOra.ts @@ -7,11 +7,12 @@ export default async function withOra( message: string | { loading: string, success?: string, - fail?: string + fail?: string, + useStatusLogs?: boolean }, callback: () => Promise ): Promise { - if (isRunningInsideGoogleColab) + if (isRunningInsideGoogleColab || (typeof message !== "string" && message.useStatusLogs)) return withStatusLogs(message, callback); const spinner = ora({ @@ -19,7 +20,7 @@ export default async function withOra( ...( typeof message === "string" ? {text: message} satisfies Parameters[0] - : message + : {loading: message.loading, success: message.success, fail: message.fail} ) }); From cc5cb5bce309692065d9f3f41201a7e485360f6e Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 29 Mar 2024 01:38:01 +0200 Subject: [PATCH 25/52] docs: update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d4e7bbaf..2fbe3f6e 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ ## Features * Run a text generation model locally on your machine * Metal, CUDA and Vulkan support -* Pre-built binaries are provided, with a fallback to building from source without `node-gyp` or Python +* Pre-built binaries are provided, with a fallback to building from source _**without**_ `node-gyp` or Python * Chat with a model using a chat wrapper * Use the CLI to chat with a model without writing any code * Up-to-date with the latest version of `llama.cpp`. Download and compile the latest release with a single CLI command. From d50f3a4496d41c30af4953e1e095ecc71f88542d Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 29 Mar 2024 01:59:38 +0200 Subject: [PATCH 26/52] fix: CUDA GPU info --- llama/gpuInfo/cuda-gpu-info.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/llama/gpuInfo/cuda-gpu-info.cu b/llama/gpuInfo/cuda-gpu-info.cu index d681b46f..1b0a21d9 100644 --- a/llama/gpuInfo/cuda-gpu-info.cu +++ b/llama/gpuInfo/cuda-gpu-info.cu @@ -1,5 +1,6 @@ #include #include +#include #if defined(GPU_INFO_USE_HIPBLAS) #include From 7342633a24c840fc5dddb5e5a6f95bad3efe05b8 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 29 Mar 2024 03:03:54 +0300 Subject: [PATCH 27/52] fix: CUDA GPU info --- llama/gpuInfo/cuda-gpu-info.h | 1 + 1 file changed, 1 insertion(+) diff --git a/llama/gpuInfo/cuda-gpu-info.h b/llama/gpuInfo/cuda-gpu-info.h index 570bdc69..e77b6f29 100644 --- a/llama/gpuInfo/cuda-gpu-info.h +++ b/llama/gpuInfo/cuda-gpu-info.h @@ -2,6 +2,7 @@ #include #include +#include typedef void (*gpuInfoCudaErrorLogCallback_t)(const char* message); From f6ca540dfce92d2b952eecc63546a9b304843cad Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 29 Mar 2024 03:05:00 +0300 Subject: [PATCH 28/52] fix: CUDA GPU info --- llama/gpuInfo/cuda-gpu-info.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama/gpuInfo/cuda-gpu-info.cu b/llama/gpuInfo/cuda-gpu-info.cu index 1b0a21d9..1559fc0b 100644 --- a/llama/gpuInfo/cuda-gpu-info.cu +++ b/llama/gpuInfo/cuda-gpu-info.cu @@ -114,7 +114,7 @@ void gpuInfoGetCudaDeviceNames(std::vector * deviceNames, gpuInfoCu if (getDevicePropertiesResult != cudaSuccess) { errorLogCallback(cudaGetErrorString(getDevicePropertiesResult)); } else { - (*deviceNames)->push_back(std::string(prop.name)); + (*deviceNames).push_back(std::string(prop.name)); } } } From 98d0b888382a428dab7a557fd08233cef9cf0e96 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 29 Mar 2024 03:15:19 +0300 Subject: [PATCH 29/52] fix: use the CUDA integration instead of the deprecated cuBLAS integration --- llama/CMakeLists.txt | 10 +++++----- llama/addon.cpp | 10 +++++----- src/bindings/utils/compileLLamaCpp.ts | 4 ++-- src/bindings/utils/resolveCustomCmakeOptions.ts | 2 +- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/llama/CMakeLists.txt b/llama/CMakeLists.txt index 5a67f085..1e3bb759 100644 --- a/llama/CMakeLists.txt +++ b/llama/CMakeLists.txt @@ -29,19 +29,19 @@ include_directories("gpuInfo") include_directories("llama.cpp") include_directories("./llama.cpp/common") -if (LLAMA_CUBLAS) +if (LLAMA_CUDA) cmake_minimum_required(VERSION 3.17) find_package(CUDAToolkit) if (CUDAToolkit_FOUND) - message(STATUS "Using cuBLAS for GPU info") + message(STATUS "Using CUDA for GPU info") enable_language(CUDA) set(GPU_INFO_HEADERS ${GPU_INFO_HEADERS} gpuInfo/cuda-gpu-info.h) set(GPU_INFO_SOURCES ${GPU_INFO_SOURCES} gpuInfo/cuda-gpu-info.cu) - add_compile_definitions(GPU_INFO_USE_CUBLAS) + add_compile_definitions(GPU_INFO_USE_CUDA) if (LLAMA_STATIC) set(LLAMA_EXTRA_LIBS ${GPU_INFO_EXTRA_LIBS} CUDA::cudart_static) @@ -60,7 +60,7 @@ if (LLAMA_CUBLAS) endif() endif() else() - message(FATAL_ERROR "cuBLAS was not found") + message(FATAL_ERROR "CUDA was not found") endif() endif() @@ -100,7 +100,7 @@ if (LLAMA_HIPBLAS) if (${hipblas_FOUND} AND ${hip_FOUND}) message(STATUS "Using HIP and hipBLAS for GPU info") - add_compile_definitions(GPU_INFO_USE_HIPBLAS GPU_INFO_USE_CUBLAS) + add_compile_definitions(GPU_INFO_USE_HIPBLAS GPU_INFO_USE_CUDA) add_library(gpu-info-rocm OBJECT gpuInfo/cuda-gpu-info.cu gpuInfo/cuda-gpu-info.h) set_source_files_properties(gpuInfo/cuda-gpu-info.cu PROPERTIES LANGUAGE CXX) target_link_libraries(gpu-info-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas) diff --git a/llama/addon.cpp b/llama/addon.cpp index 273adfca..f9da4a49 100644 --- a/llama/addon.cpp +++ b/llama/addon.cpp @@ -9,7 +9,7 @@ #include "llama.h" #include "napi.h" -#ifdef GPU_INFO_USE_CUBLAS +#ifdef GPU_INFO_USE_CUDA # include "gpuInfo/cuda-gpu-info.h" #endif #ifdef GPU_INFO_USE_VULKAN @@ -121,7 +121,7 @@ std::string addon_model_token_to_piece(const struct llama_model* model, llama_to return std::string(result.data(), result.size()); } -#ifdef GPU_INFO_USE_CUBLAS +#ifdef GPU_INFO_USE_CUDA void logCudaError(const char* message) { addonLlamaCppLogCallback(GGML_LOG_LEVEL_ERROR, (std::string("CUDA error: ") + std::string(message)).c_str(), nullptr); } @@ -136,7 +136,7 @@ Napi::Value getGpuVramInfo(const Napi::CallbackInfo& info) { uint64_t total = 0; uint64_t used = 0; -#ifdef GPU_INFO_USE_CUBLAS +#ifdef GPU_INFO_USE_CUDA size_t cudaDeviceTotal = 0; size_t cudaDeviceUsed = 0; bool cudeGetInfoSuccess = gpuInfoGetTotalCudaDevicesInfo(&cudaDeviceTotal, &cudaDeviceUsed, logCudaError); @@ -177,7 +177,7 @@ Napi::Value getGpuVramInfo(const Napi::CallbackInfo& info) { Napi::Value getGpuDeviceInfo(const Napi::CallbackInfo& info) { std::vector deviceNames; -#ifdef GPU_INFO_USE_CUBLAS +#ifdef GPU_INFO_USE_CUDA gpuInfoGetCudaDeviceNames(&deviceNames, logCudaError); #endif @@ -201,7 +201,7 @@ Napi::Value getGpuDeviceInfo(const Napi::CallbackInfo& info) { } Napi::Value getGpuType(const Napi::CallbackInfo& info) { -#ifdef GPU_INFO_USE_CUBLAS +#ifdef GPU_INFO_USE_CUDA return Napi::String::New(info.Env(), "cuda"); #endif diff --git a/src/bindings/utils/compileLLamaCpp.ts b/src/bindings/utils/compileLLamaCpp.ts index 59d42926..6abb2f8b 100644 --- a/src/bindings/utils/compileLLamaCpp.ts +++ b/src/bindings/utils/compileLLamaCpp.ts @@ -73,8 +73,8 @@ export async function compileLlamaCpp(buildOptions: BuildOptions, compileOptions else if (!cmakeCustomOptions.has("LLAMA_METAL")) cmakeCustomOptions.set("LLAMA_METAL", "OFF"); - if (buildOptions.gpu === "cuda" && !cmakeCustomOptions.has("LLAMA_CUBLAS")) - cmakeCustomOptions.set("LLAMA_CUBLAS", "1"); + if (buildOptions.gpu === "cuda" && !cmakeCustomOptions.has("LLAMA_CUDA")) + cmakeCustomOptions.set("LLAMA_CUDA", "1"); if (buildOptions.gpu === "vulkan" && !cmakeCustomOptions.has("LLAMA_VULKAN")) cmakeCustomOptions.set("LLAMA_VULKAN", "1"); diff --git a/src/bindings/utils/resolveCustomCmakeOptions.ts b/src/bindings/utils/resolveCustomCmakeOptions.ts index f9a11006..1e56d53b 100644 --- a/src/bindings/utils/resolveCustomCmakeOptions.ts +++ b/src/bindings/utils/resolveCustomCmakeOptions.ts @@ -7,7 +7,7 @@ export function resolveCustomCmakeOptions(customCmakeOptions?: Record Date: Fri, 29 Mar 2024 03:29:07 +0300 Subject: [PATCH 30/52] fix: Vulkan GPU info --- llama/gpuInfo/vulkan-gpu-info.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/llama/gpuInfo/vulkan-gpu-info.cpp b/llama/gpuInfo/vulkan-gpu-info.cpp index 4bdbbe7d..0b9a6556 100644 --- a/llama/gpuInfo/vulkan-gpu-info.cpp +++ b/llama/gpuInfo/vulkan-gpu-info.cpp @@ -5,17 +5,6 @@ typedef void (*gpuInfoVulkanWarningLogCallback_t)(const char* message); -bool gpuInfoGetTotalVulkanDevicesInfo(size_t* total, size_t* used, gpuInfoVulkanWarningLogCallback_t warningLogCallback) { - return enumerateVulkanDevices(total, used, false, nullptr, warningLogCallback); -} - -bool gpuInfoGetVulkanDeviceNames(std::vector * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback) { - size_t vulkanDeviceTotal = 0; - size_t vulkanDeviceUsed = 0; - - return enumerateVulkanDevices(&vulkanDeviceTotal, &vulkanDeviceUsed, true, deviceNames, warningLogCallback); -} - static bool enumerateVulkanDevices(size_t* total, size_t* used, bool addDeviceNames, std::vector * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback) { vk::ApplicationInfo appInfo("node-llama-cpp GPU info", 1, "llama.cpp", 1, VK_API_VERSION_1_2); vk::InstanceCreateInfo createInfo(vk::InstanceCreateFlags(), &appInfo, {}, {}); @@ -81,3 +70,14 @@ static bool enumerateVulkanDevices(size_t* total, size_t* used, bool addDeviceNa *used = usedMem; return true; } + +bool gpuInfoGetTotalVulkanDevicesInfo(size_t* total, size_t* used, gpuInfoVulkanWarningLogCallback_t warningLogCallback) { + return enumerateVulkanDevices(total, used, false, nullptr, warningLogCallback); +} + +bool gpuInfoGetVulkanDeviceNames(std::vector * deviceNames, gpuInfoVulkanWarningLogCallback_t warningLogCallback) { + size_t vulkanDeviceTotal = 0; + size_t vulkanDeviceUsed = 0; + + return enumerateVulkanDevices(&vulkanDeviceTotal, &vulkanDeviceUsed, true, deviceNames, warningLogCallback); +} From e9fe2080e19048478be415031b30c86d04a5d069 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 29 Mar 2024 03:46:02 +0300 Subject: [PATCH 31/52] test: fix snapshots --- .../__snapshots__/ggufParser.test.ts.snap | 30 +++++++++++++++++++ .../ggufStandaloneParser.test.ts.snap | 24 +++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/test/modelDependent/functionary/gguf/__snapshots__/ggufParser.test.ts.snap b/test/modelDependent/functionary/gguf/__snapshots__/ggufParser.test.ts.snap index 3d356afc..6d2fc082 100644 --- a/test/modelDependent/functionary/gguf/__snapshots__/ggufParser.test.ts.snap +++ b/test/modelDependent/functionary/gguf/__snapshots__/ggufParser.test.ts.snap @@ -2,6 +2,21 @@ exports[`gguf > parser > should fetch GGUF metadata 1`] = ` { + "architectureMetadata": { + "attention": { + "head_count": 32, + "head_count_kv": 8, + "layer_norm_rms_epsilon": 0.000009999999747378752, + }, + "block_count": 32, + "context_length": 32768, + "embedding_length": 4096, + "feed_forward_length": 14336, + "rope": { + "dimension_count": 128, + "freq_base": 10000, + }, + }, "metadata": { "general": { "architecture": "llama", @@ -148,6 +163,21 @@ exports[`gguf > parser > should fetch GGUF metadata 1`] = ` exports[`gguf > parser > should parse local gguf model 1`] = ` { + "architectureMetadata": { + "attention": { + "head_count": 32, + "head_count_kv": 8, + "layer_norm_rms_epsilon": 0.000009999999747378752, + }, + "block_count": 32, + "context_length": 32768, + "embedding_length": 4096, + "feed_forward_length": 14336, + "rope": { + "dimension_count": 128, + "freq_base": 10000, + }, + }, "metadata": { "general": { "architecture": "llama", diff --git a/test/standalone/gguf/__snapshots__/ggufStandaloneParser.test.ts.snap b/test/standalone/gguf/__snapshots__/ggufStandaloneParser.test.ts.snap index 0c07ab18..69f38dd1 100644 --- a/test/standalone/gguf/__snapshots__/ggufStandaloneParser.test.ts.snap +++ b/test/standalone/gguf/__snapshots__/ggufStandaloneParser.test.ts.snap @@ -2,6 +2,18 @@ exports[`gguf > parser > should parse remote gguf model 1`] = ` { + "architectureMetadata": { + "attention": { + "head_count": 232, + "head_count_kv": 8, + "layer_norm_epsilon": 0.000009999999747378752, + }, + "block_count": 80, + "context_length": 2048, + "embedding_length": 14848, + "feed_forward_length": 59392, + "tensor_data_layout": "jploski", + }, "metadata": { "falcon": { "attention": { @@ -123,6 +135,18 @@ exports[`gguf > parser > should parse remote gguf model 1`] = ` exports[`gguf > parser > should parse remote gguf model without tensor info 1`] = ` { + "architectureMetadata": { + "attention": { + "head_count": 232, + "head_count_kv": 8, + "layer_norm_epsilon": 0.000009999999747378752, + }, + "block_count": 80, + "context_length": 2048, + "embedding_length": 14848, + "feed_forward_length": 59392, + "tensor_data_layout": "jploski", + }, "metadata": { "falcon": { "attention": { From f49945168a497517fcccd61bb0842a3ec02a4975 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Fri, 29 Mar 2024 04:10:59 +0300 Subject: [PATCH 32/52] test: add a sanity test to validate that an input is tokenized properly with special tokens only when `specialTokens` is enabled --- .../modelDependent/functionary/sanity.test.ts | 69 +++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/test/modelDependent/functionary/sanity.test.ts b/test/modelDependent/functionary/sanity.test.ts index 4204b0d2..6a8e5ca8 100644 --- a/test/modelDependent/functionary/sanity.test.ts +++ b/test/modelDependent/functionary/sanity.test.ts @@ -23,5 +23,74 @@ describe("functionary", () => { expect(res).to.eql("6+6 equals 12."); }); + + test("text is tokenized with special tokens when appropriate", {timeout: 1000 * 60 * 60 * 2}, async () => { + const modelPath = await getModelFile("functionary-small-v2.2.q4_0.gguf"); + const llama = await getTestLlama(); + + const model = await llama.loadModel({ + modelPath + }); + + const text = "<|from|>system\n<|recipient|>all\n<|content|>How much is 6+6\n"; + + const tokensWithSpecialTokens = model.tokenize(text, true); + const tokensWithoutSpecialTokens = model.tokenize(text); + + expect(tokensWithSpecialTokens).to.not.eql(tokensWithoutSpecialTokens); + + expect(tokensWithSpecialTokens).to.toMatchInlineSnapshot(` + [ + 32002, + 6574, + 13, + 32001, + 455, + 13, + 32000, + 5660, + 1188, + 349, + 28705, + 28784, + 28806, + 28784, + 13, + ] + `); + expect(tokensWithoutSpecialTokens).to.toMatchInlineSnapshot(` + [ + 523, + 28766, + 3211, + 28766, + 28767, + 6574, + 13, + 28789, + 28766, + 3354, + 508, + 722, + 28766, + 28767, + 455, + 13, + 28789, + 28766, + 3789, + 28766, + 28767, + 5660, + 1188, + 349, + 28705, + 28784, + 28806, + 28784, + 13, + ] + `); + }); }); }); From 2e324b2ca5f66a7b3cfdb150c49fe689862d5af8 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 2 Apr 2024 20:09:59 +0300 Subject: [PATCH 33/52] feat(`JinjaTemplateChatWrapper`): use `tokenizer.chat_template` from a model metadata when possible Detect when an equivalent specialized chat wrapper is available and use it instead. --- llama/addon.cpp | 11 + package-lock.json | 9 + package.json | 1 + src/ChatWrapper.ts | 7 + src/bindings/AddonTypes.ts | 1 + src/bindings/types.ts | 13 + ...esolveChatWrapperBasedOnWrapperTypeName.ts | 74 -- src/chatWrappers/AlpacaChatWrapper.ts | 15 +- src/chatWrappers/ChatMLChatWrapper.ts | 4 +- src/chatWrappers/FalconChatWrapper.ts | 42 +- src/chatWrappers/FunctionaryChatWrapper.ts | 95 ++- src/chatWrappers/GemmaChatWrapper.ts | 3 +- src/chatWrappers/GeneralChatWrapper.ts | 49 +- src/chatWrappers/LlamaChatWrapper.ts | 29 +- .../generic/JinjaTemplateChatWrapper.ts | 458 +++++++++++ .../generic}/TemplateChatWrapper.ts | 110 +-- .../chatHistoryFunctionCallMessageTemplate.ts | 77 ++ .../resolveChatWrapperBasedOnModel.ts | 79 -- ...plateEquivalentToSpecializedChatWrapper.ts | 249 ++++++ src/chatWrappers/utils/resolveChatWrapper.ts | 249 ++++++ src/cli/commands/ChatCommand.ts | 19 +- .../inspect/commands/InspectGgufCommand.ts | 3 +- src/evaluator/LlamaChat/LlamaChat.ts | 14 +- src/evaluator/LlamaModel.ts | 57 +- src/gguf/readGgufFileInfo.ts | 3 +- src/gguf/types/GgufMetadataTypes.ts | 10 +- src/gguf/utils/normalizeGgufDownloadUrl.ts | 20 + src/index.ts | 61 +- src/types.ts | 2 +- src/utils/LlamaText.ts | 201 ++++- src/utils/resolveChatWrapper.ts | 21 - .../generic/JinjaTemplateChatWrapper.test.ts | 753 ++++++++++++++++++ .../LlamaChatPromptWrapper.test.ts | 4 +- .../utils/resolveChatWrapper.test.ts | 207 +++++ test/utils/helpers/llamaTextSerializer.ts | 11 + vitest.config.ts | 3 +- 36 files changed, 2599 insertions(+), 365 deletions(-) delete mode 100644 src/bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.ts create mode 100644 src/chatWrappers/generic/JinjaTemplateChatWrapper.ts rename src/{ => chatWrappers/generic}/TemplateChatWrapper.ts (79%) create mode 100644 src/chatWrappers/generic/utils/chatHistoryFunctionCallMessageTemplate.ts delete mode 100644 src/chatWrappers/resolveChatWrapperBasedOnModel.ts create mode 100644 src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts create mode 100644 src/chatWrappers/utils/resolveChatWrapper.ts create mode 100644 src/gguf/utils/normalizeGgufDownloadUrl.ts delete mode 100644 src/utils/resolveChatWrapper.ts create mode 100644 test/standalone/chatWrappers/generic/JinjaTemplateChatWrapper.test.ts rename test/standalone/chatWrappers/{ => generic}/LlamaChatPromptWrapper.test.ts (97%) create mode 100644 test/standalone/chatWrappers/utils/resolveChatWrapper.test.ts create mode 100644 test/utils/helpers/llamaTextSerializer.ts diff --git a/llama/addon.cpp b/llama/addon.cpp index f9da4a49..4651a6fe 100644 --- a/llama/addon.cpp +++ b/llama/addon.cpp @@ -533,6 +533,16 @@ class AddonModel : public Napi::ObjectWrap { return Napi::Number::From(info.Env(), int32_t(tokenType)); } + Napi::Value GetVocabularyType(const Napi::CallbackInfo& info) { + if (disposed) { + Napi::Error::New(info.Env(), "Model is disposed").ThrowAsJavaScriptException(); + return info.Env().Undefined(); + } + + auto vocabularyType = llama_vocab_type(model); + + return Napi::Number::From(info.Env(), int32_t(vocabularyType)); + } Napi::Value ShouldPrependBosToken(const Napi::CallbackInfo& info) { const int addBos = llama_add_bos_token(model); @@ -570,6 +580,7 @@ class AddonModel : public Napi::ObjectWrap { InstanceMethod("eotToken", &AddonModel::EotToken), InstanceMethod("getTokenString", &AddonModel::GetTokenString), InstanceMethod("getTokenType", &AddonModel::GetTokenType), + InstanceMethod("getVocabularyType", &AddonModel::GetVocabularyType), InstanceMethod("shouldPrependBosToken", &AddonModel::ShouldPrependBosToken), InstanceMethod("getModelSize", &AddonModel::GetModelSize), InstanceMethod("dispose", &AddonModel::Dispose), diff --git a/package-lock.json b/package-lock.json index e326c592..6b62664d 100644 --- a/package-lock.json +++ b/package-lock.json @@ -9,6 +9,7 @@ "version": "0.1.0", "license": "MIT", "dependencies": { + "@huggingface/jinja": "^0.2.2", "async-retry": "^1.3.3", "bytes": "^3.1.2", "chalk": "^5.3.0", @@ -1428,6 +1429,14 @@ "node": "^12.22.0 || ^14.17.0 || >=16.0.0" } }, + "node_modules/@huggingface/jinja": { + "version": "0.2.2", + "resolved": "https://registry.npmjs.org/@huggingface/jinja/-/jinja-0.2.2.tgz", + "integrity": "sha512-/KPde26khDUIPkTGU82jdtTW9UAuvUTumCAbFs/7giR0SxsvZC4hru51PBvpijH6BVkHcROcvZM/lpy5h1jRRA==", + "engines": { + "node": ">=18" + } + }, "node_modules/@humanwhocodes/config-array": { "version": "0.11.13", "resolved": "https://registry.npmjs.org/@humanwhocodes/config-array/-/config-array-0.11.13.tgz", diff --git a/package.json b/package.json index a2a075e8..36eabb65 100644 --- a/package.json +++ b/package.json @@ -149,6 +149,7 @@ "zx": "^7.2.3" }, "dependencies": { + "@huggingface/jinja": "^0.2.2", "async-retry": "^1.3.3", "bytes": "^3.1.2", "chalk": "^5.3.0", diff --git a/src/ChatWrapper.ts b/src/ChatWrapper.ts index 4e8ae6a0..d8621359 100644 --- a/src/ChatWrapper.ts +++ b/src/ChatWrapper.ts @@ -185,4 +185,11 @@ export abstract class ChatWrapper { return {}; } + + /** @internal */ + public static _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate(): Record[] { + return [{}] satisfies Partial, object>>[]; + } } + +type FirstItemOfTupleOrFallback = T extends [infer U, ...any[]] ? U : Fallback; diff --git a/src/bindings/AddonTypes.ts b/src/bindings/AddonTypes.ts index ed13b1fa..a7d7c075 100644 --- a/src/bindings/AddonTypes.ts +++ b/src/bindings/AddonTypes.ts @@ -81,6 +81,7 @@ export type AddonModel = { eotToken(): Token, getTokenString(token: number): string, getTokenType(token: Token): number, + getVocabularyType(): number, shouldPrependBosToken(): boolean, getModelSize(): number }; diff --git a/src/bindings/types.ts b/src/bindings/types.ts index 528806a0..204b00a5 100644 --- a/src/bindings/types.ts +++ b/src/bindings/types.ts @@ -78,6 +78,19 @@ export const LlamaLogLevelValues = Object.freeze([ LlamaLogLevel.debug ] as const); +export enum LlamaVocabularyType { + none = "none", + spm = "spm", + bpe = "bpe", + wpm = "wpm", +} +export const LlamaVocabularyTypeValues = Object.freeze([ + LlamaVocabularyType.none, + LlamaVocabularyType.spm, + LlamaVocabularyType.bpe, + LlamaVocabularyType.wpm +] as const); + /** *Check if a log level is higher than another log level */ diff --git a/src/bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.ts b/src/bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.ts deleted file mode 100644 index e24e7fd5..00000000 --- a/src/bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.ts +++ /dev/null @@ -1,74 +0,0 @@ -import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js"; -import {LlamaChatWrapper} from "../../chatWrappers/LlamaChatWrapper.js"; -import {AlpacaChatWrapper} from "../../chatWrappers/AlpacaChatWrapper.js"; -import {FunctionaryChatWrapper} from "../../chatWrappers/FunctionaryChatWrapper.js"; -import {ChatMLChatWrapper} from "../../chatWrappers/ChatMLChatWrapper.js"; -import {FalconChatWrapper} from "../../chatWrappers/FalconChatWrapper.js"; -import {resolveChatWrapperBasedOnModel} from "../../chatWrappers/resolveChatWrapperBasedOnModel.js"; -import {GemmaChatWrapper} from "../../chatWrappers/GemmaChatWrapper.js"; -import {GgufFileInfo} from "../../gguf/types/GgufFileInfoTypes.js"; - -export const chatWrapperTypeNames = Object.freeze([ - "auto", "general", "llamaChat", "alpacaChat", "functionary", "chatML", "falconChat", "gemma" -] as const); -export type ChatWrapperTypeName = (typeof chatWrapperTypeNames)[number]; - -const chatWrappers = { - "general": GeneralChatWrapper, - "llamaChat": LlamaChatWrapper, - "alpacaChat": AlpacaChatWrapper, - "functionary": FunctionaryChatWrapper, - "chatML": ChatMLChatWrapper, - "falconChat": FalconChatWrapper, - "gemma": GemmaChatWrapper -} as const satisfies Record, any>; -const chatWrapperToConfigType = new Map( - Object.entries(chatWrappers).map(([configType, Wrapper]) => [Wrapper, configType]) -); - -/** - * @param configType - * @param options - */ -export function resolveChatWrapperBasedOnWrapperTypeName(configType: ChatWrapperTypeName, { - bosString, - filename, - fileInfo, - customWrapperSettings -}: { - bosString?: string | null, - filename?: string, - fileInfo?: GgufFileInfo, - customWrapperSettings?: { - [wrapper in keyof typeof chatWrappers]?: ConstructorParameters<(typeof chatWrappers)[wrapper]>[0] - } -} = {}) { - if (Object.hasOwn(chatWrappers, configType)) { - const Wrapper = chatWrappers[configType as keyof typeof chatWrappers]; - const wrapperSettings: ConstructorParameters[0] | undefined = - customWrapperSettings?.[configType as keyof typeof chatWrappers]; - - return new Wrapper(wrapperSettings); - } - - if (configType === "auto") { - const chatWrapper = resolveChatWrapperBasedOnModel({ - bosString, - filename, - fileInfo - }); - - if (chatWrapper != null) { - const resolvedConfigType = chatWrapperToConfigType.get(chatWrapper); - const wrapperSettings: ConstructorParameters[0] | undefined = resolvedConfigType == null - ? undefined - : customWrapperSettings?.[resolvedConfigType as keyof typeof chatWrappers]; - - return new chatWrapper(wrapperSettings); - } - - return new GeneralChatWrapper(customWrapperSettings?.general); - } - - throw new Error("Unknown wrapper config: " + configType); -} diff --git a/src/chatWrappers/AlpacaChatWrapper.ts b/src/chatWrappers/AlpacaChatWrapper.ts index 64c96003..09bb6297 100644 --- a/src/chatWrappers/AlpacaChatWrapper.ts +++ b/src/chatWrappers/AlpacaChatWrapper.ts @@ -4,14 +4,16 @@ export class AlpacaChatWrapper extends GeneralChatWrapper { public override readonly wrapperName: string = "AlpacaChat"; public constructor({ - userMessageTitle = "Instruction", modelResponseTitle = "Response", middleSystemMessageTitle = "System" + userMessageTitle = "Instruction", modelResponseTitle = "Response", middleSystemMessageTitle = "System", + allowSpecialTokensInTitles = false }: { - userMessageTitle?: string, modelResponseTitle?: string, middleSystemMessageTitle?: string + userMessageTitle?: string, modelResponseTitle?: string, middleSystemMessageTitle?: string, allowSpecialTokensInTitles?: boolean } = {}) { super({ userMessageTitle: userMessageTitle + ":", modelResponseTitle: modelResponseTitle + ":", - middleSystemMessageTitle: middleSystemMessageTitle + ":" + middleSystemMessageTitle: middleSystemMessageTitle + ":", + allowSpecialTokensInTitles }); } @@ -26,4 +28,11 @@ export class AlpacaChatWrapper extends GeneralChatWrapper { public override get middleSystemMessageTitle() { return super.middleSystemMessageTitle.slice(0, -1); } + + /** @internal */ + public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { + return [{ + allowSpecialTokensInTitles: true + }] satisfies Partial[0]>[]; + } } diff --git a/src/chatWrappers/ChatMLChatWrapper.ts b/src/chatWrappers/ChatMLChatWrapper.ts index 456f9035..18fb57ed 100644 --- a/src/chatWrappers/ChatMLChatWrapper.ts +++ b/src/chatWrappers/ChatMLChatWrapper.ts @@ -63,12 +63,14 @@ export class ChatMLChatWrapper extends ChatWrapper { currentAggregateFocus = null; modelTexts.push(this.generateModelResponseText(item.response)); - } + } else + void (item satisfies never); } flush(); const contextText = LlamaText( + new BuiltinSpecialToken("BOS"), resultItems.map(({system, user, model}, index) => { const isLastItem = index === resultItems.length - 1; diff --git a/src/chatWrappers/FalconChatWrapper.ts b/src/chatWrappers/FalconChatWrapper.ts index 63154262..36096ccf 100644 --- a/src/chatWrappers/FalconChatWrapper.ts +++ b/src/chatWrappers/FalconChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {LlamaText, BuiltinSpecialToken} from "../utils/LlamaText.js"; +import {LlamaText, BuiltinSpecialToken, SpecialToken} from "../utils/LlamaText.js"; export class FalconChatWrapper extends ChatWrapper { public readonly wrapperName: string = "Falcon"; @@ -8,17 +8,19 @@ export class FalconChatWrapper extends ChatWrapper { /** @internal */ private readonly _userMessageTitle: string; /** @internal */ private readonly _modelResponseTitle: string; /** @internal */ private readonly _middleSystemMessageTitle: string; + /** @internal */ private readonly _allowSpecialTokensInTitles: boolean; public constructor({ - userMessageTitle = "User", modelResponseTitle = "Assistant", middleSystemMessageTitle = "System" + userMessageTitle = "User", modelResponseTitle = "Assistant", middleSystemMessageTitle = "System", allowSpecialTokensInTitles = false }: { - userMessageTitle?: string, modelResponseTitle?: string, middleSystemMessageTitle?: string + userMessageTitle?: string, modelResponseTitle?: string, middleSystemMessageTitle?: string, allowSpecialTokensInTitles?: boolean } = {}) { super(); this._userMessageTitle = userMessageTitle; this._modelResponseTitle = modelResponseTitle; this._middleSystemMessageTitle = middleSystemMessageTitle; + this._allowSpecialTokensInTitles = allowSpecialTokensInTitles; } public get userMessageTitle() { @@ -85,7 +87,8 @@ export class FalconChatWrapper extends ChatWrapper { currentAggregateFocus = null; modelTexts.push(this.generateModelResponseText(item.response)); - } + } else + void (item satisfies never); } flush(); @@ -102,27 +105,27 @@ export class FalconChatWrapper extends ChatWrapper { : LlamaText([ isFirstItem ? LlamaText([]) - : `${this._middleSystemMessageTitle}: `, + : SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `${this._middleSystemMessageTitle}: `), system, - "\n\n" + SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (user.length === 0) ? LlamaText([]) : LlamaText([ - `${this._userMessageTitle}: `, + SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `${this._userMessageTitle}: `), user, - "\n\n" + SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (model.length === 0 && !isLastItem) ? LlamaText([]) : LlamaText([ - `${this._modelResponseTitle}: `, + SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `${this._modelResponseTitle}: `), model, isLastItem ? LlamaText([]) - : "\n\n" + : SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]) ]); }) @@ -135,8 +138,25 @@ export class FalconChatWrapper extends ChatWrapper { LlamaText(`\n${this._userMessageTitle}:`), LlamaText(`\n${this._modelResponseTitle}:`), - LlamaText(`\n${this._middleSystemMessageTitle}:`) + LlamaText(`\n${this._middleSystemMessageTitle}:`), + + ...( + !this._allowSpecialTokensInTitles + ? [] + : [ + LlamaText(new SpecialToken(`\n${this._userMessageTitle}:`)), + LlamaText(new SpecialToken(`\n${this._modelResponseTitle}:`)), + LlamaText(new SpecialToken(`\n${this._middleSystemMessageTitle}:`)) + ] + ) ] }; } + + /** @internal */ + public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { + return [{ + allowSpecialTokensInTitles: true + }] satisfies Partial[0]>[]; + } } diff --git a/src/chatWrappers/FunctionaryChatWrapper.ts b/src/chatWrappers/FunctionaryChatWrapper.ts index e6127a79..8c4e65cf 100644 --- a/src/chatWrappers/FunctionaryChatWrapper.ts +++ b/src/chatWrappers/FunctionaryChatWrapper.ts @@ -11,13 +11,13 @@ export class FunctionaryChatWrapper extends ChatWrapper { functions: { call: { optionalPrefixSpace: true, - prefix: "\n\n<|from|>assistant\n<|recipient|>", + prefix: "\n<|from|>assistant\n<|recipient|>", paramsPrefix: "\n<|content|>", - suffix: "\n\n" + suffix: "\n" }, result: { prefix: "<|from|>{{functionName}}\n<|recipient|>all\n<|content|>", - suffix: "\n\n" + suffix: "\n" } } }; @@ -53,20 +53,20 @@ export class FunctionaryChatWrapper extends ChatWrapper { return LlamaText([ isFirstItem ? LlamaText([]) - : "\n\n", - "<|from|>system\n", - "<|recipient|>all\n", - "<|content|>", + : new SpecialToken("\n"), + new SpecialToken("<|from|>system\n"), + new SpecialToken("<|recipient|>all\n"), + new SpecialToken("<|content|>"), item.text ]); } else if (item.type === "user") { return LlamaText([ isFirstItem ? LlamaText([]) - : "\n\n", - "<|from|>user\n", - "<|recipient|>all\n", - "<|content|>", + : new SpecialToken("\n"), + new SpecialToken("<|from|>user\n"), + new SpecialToken("<|recipient|>all\n"), + new SpecialToken("<|content|>"), item.text ]); } else if (item.type === "model") { @@ -74,10 +74,10 @@ export class FunctionaryChatWrapper extends ChatWrapper { return LlamaText([ isFirstItem ? LlamaText([]) - : "\n\n", - "<|from|>assistant\n", - "<|recipient|>all\n", - "<|content|>" + : new SpecialToken("\n"), + new SpecialToken("<|from|>assistant\n"), + new SpecialToken("<|recipient|>all\n"), + new SpecialToken("<|content|>") ]); return LlamaText( @@ -89,14 +89,14 @@ export class FunctionaryChatWrapper extends ChatWrapper { return LlamaText([ (isFirstItem && isFirstResponse) ? LlamaText([]) - : "\n\n", - "<|from|>assistant\n", - "<|recipient|>all\n", - "<|content|>", + : new SpecialToken("\n"), + new SpecialToken("<|from|>assistant\n"), + new SpecialToken("<|recipient|>all\n"), + new SpecialToken("<|content|>"), response, (isLastResponse && isLastItem) ? "" - : "<|stop|>" + : new SpecialToken("<|stop|>") ]); else if (isChatModelResponseFunctionCall(response)) { return LlamaText([ @@ -105,20 +105,20 @@ export class FunctionaryChatWrapper extends ChatWrapper { : LlamaText([ (isFirstItem && isFirstResponse) ? LlamaText([]) - : "\n\n", + : new SpecialToken("\n"), - "<|from|>assistant\n", - `<|recipient|>${response.name}\n`, - "<|content|>", + new SpecialToken("<|from|>assistant\n"), + new SpecialToken("<|recipient|>"), response.name, new SpecialToken("\n"), + new SpecialToken("<|content|>"), response.params === undefined ? "" : JSON.stringify(response.params), - "<|stop|>", + new SpecialToken("<|stop|>"), - "\n\n", - `<|from|>${response.name}\n`, - "<|recipient|>all\n", - "<|content|>", + new SpecialToken("\n"), + new SpecialToken("<|from|>"), response.name, new SpecialToken("\n"), + new SpecialToken("<|recipient|>all\n"), + new SpecialToken("<|content|>"), response.result === undefined ? "" // "void" : JSON.stringify(response.result) @@ -127,10 +127,10 @@ export class FunctionaryChatWrapper extends ChatWrapper { hasFunctions ? LlamaText([]) : LlamaText([ - "\n\n", - "<|from|>assistant\n", - "<|recipient|>all\n", - "<|content|>" + new SpecialToken("\n"), + new SpecialToken("<|from|>assistant\n"), + new SpecialToken("<|recipient|>all\n"), + new SpecialToken("<|content|>") ]) ]); } @@ -154,9 +154,15 @@ export class FunctionaryChatWrapper extends ChatWrapper { LlamaText(new SpecialToken("<|stop|>")), LlamaText(" <|stop|>"), LlamaText("<|stop|>"), - LlamaText("\n\n<|from|>user"), - LlamaText("\n\n<|from|>assistant"), - LlamaText("\n\n<|from|>system") + LlamaText("\n<|from|>user"), + LlamaText("\n<|from|>assistant"), + LlamaText("\n<|from|>system"), + + LlamaText(new SpecialToken(" <|stop|>")), + LlamaText(new SpecialToken("<|stop|>")), + LlamaText(new SpecialToken("\n<|from|>user")), + LlamaText(new SpecialToken("\n<|from|>assistant")), + LlamaText(new SpecialToken("\n<|from|>system")) ] }; } @@ -166,17 +172,28 @@ export class FunctionaryChatWrapper extends ChatWrapper { stopGenerationTriggers: [ LlamaText(new BuiltinSpecialToken("EOS")), LlamaText(new SpecialToken("<|stop|>")), + LlamaText(" <|stop|>"), LlamaText("<|stop|>"), - LlamaText("\n\n<|from|>user") + LlamaText("\n<|from|>user"), + + LlamaText(new SpecialToken(" <|stop|>")), + LlamaText(new SpecialToken("<|stop|>")), + LlamaText(new SpecialToken("\n<|from|>user")) ], ignoreStartText: [ - LlamaText("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>") + LlamaText("\n<|from|>assistant\n<|recipient|>all\n<|content|>"), + LlamaText(new SpecialToken("\n<|from|>assistant\n<|recipient|>all\n<|content|>")), + LlamaText("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>"), + LlamaText(new SpecialToken("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>")) ], functionCall: { initiallyEngaged: true, disengageInitiallyEngaged: [ - LlamaText("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>") + LlamaText("\n<|from|>assistant\n<|recipient|>all\n<|content|>"), + LlamaText(new SpecialToken("\n<|from|>assistant\n<|recipient|>all\n<|content|>")), + LlamaText("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>"), + LlamaText(new SpecialToken("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>")) ] } }; diff --git a/src/chatWrappers/GemmaChatWrapper.ts b/src/chatWrappers/GemmaChatWrapper.ts index c0cd7090..4aab2dc3 100644 --- a/src/chatWrappers/GemmaChatWrapper.ts +++ b/src/chatWrappers/GemmaChatWrapper.ts @@ -73,7 +73,8 @@ export class GemmaChatWrapper extends ChatWrapper { } else if (item.type === "model") { currentAggregateFocus = "model"; modelTexts.push(this.generateModelResponseText(item.response)); - } + } else + void (item satisfies never); } flush(); diff --git a/src/chatWrappers/GeneralChatWrapper.ts b/src/chatWrappers/GeneralChatWrapper.ts index 75acb1fa..7f1e1daa 100644 --- a/src/chatWrappers/GeneralChatWrapper.ts +++ b/src/chatWrappers/GeneralChatWrapper.ts @@ -8,17 +8,20 @@ export class GeneralChatWrapper extends ChatWrapper { /** @internal */ private readonly _userMessageTitle: string; /** @internal */ private readonly _modelResponseTitle: string; /** @internal */ private readonly _middleSystemMessageTitle: string; + /** @internal */ private readonly _allowSpecialTokensInTitles: boolean; public constructor({ - userMessageTitle = "Human", modelResponseTitle = "Assistant", middleSystemMessageTitle = "System" + userMessageTitle = "Human", modelResponseTitle = "Assistant", middleSystemMessageTitle = "System", + allowSpecialTokensInTitles = false }: { - userMessageTitle?: string, modelResponseTitle?: string, middleSystemMessageTitle?: string + userMessageTitle?: string, modelResponseTitle?: string, middleSystemMessageTitle?: string, allowSpecialTokensInTitles?: boolean } = {}) { super(); this._userMessageTitle = userMessageTitle; this._modelResponseTitle = modelResponseTitle; this._middleSystemMessageTitle = middleSystemMessageTitle; + this._allowSpecialTokensInTitles = allowSpecialTokensInTitles; } public get userMessageTitle() { @@ -85,7 +88,8 @@ export class GeneralChatWrapper extends ChatWrapper { currentAggregateFocus = null; modelTexts.push(this.generateModelResponseText(item.response)); - } + } else + void (item satisfies never); } flush(); @@ -102,27 +106,27 @@ export class GeneralChatWrapper extends ChatWrapper { : LlamaText([ isFirstItem ? LlamaText([]) - : `### ${this._middleSystemMessageTitle}\n`, + : SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `### ${this._middleSystemMessageTitle}\n`), system, - "\n\n" + SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (user.length === 0) ? LlamaText([]) : LlamaText([ - `### ${this._userMessageTitle}\n`, + SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `### ${this._userMessageTitle}\n`), user, - "\n\n" + SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (model.length === 0 && !isLastItem) ? LlamaText([]) : LlamaText([ - `### ${this._modelResponseTitle}\n`, + SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `### ${this._modelResponseTitle}\n`), model, isLastItem ? LlamaText([]) - : "\n\n" + : SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]) ]); }) @@ -145,8 +149,33 @@ export class GeneralChatWrapper extends ChatWrapper { LlamaText(`### ${this._middleSystemMessageTitle}`), LlamaText(`\n### ${this._middleSystemMessageTitle}`), - LlamaText(`\n\n### ${this._middleSystemMessageTitle}`) + LlamaText(`\n\n### ${this._middleSystemMessageTitle}`), + + ...( + !this._allowSpecialTokensInTitles + ? [] + : [ + LlamaText(new SpecialToken(`### ${this._userMessageTitle}`)), + LlamaText(new SpecialToken(`\n### ${this._userMessageTitle}`)), + LlamaText(new SpecialToken(`\n\n### ${this._userMessageTitle}`)), + + LlamaText(new SpecialToken(`### ${this._modelResponseTitle}`)), + LlamaText(new SpecialToken(`\n### ${this._modelResponseTitle}`)), + LlamaText(new SpecialToken(`\n\n### ${this._modelResponseTitle}`)), + + LlamaText(new SpecialToken(`### ${this._middleSystemMessageTitle}`)), + LlamaText(new SpecialToken(`\n### ${this._middleSystemMessageTitle}`)), + LlamaText(new SpecialToken(`\n\n### ${this._middleSystemMessageTitle}`)) + ] + ) ] }; } + + /** @internal */ + public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { + return [{ + allowSpecialTokensInTitles: true + }] satisfies Partial[0]>[]; + } } diff --git a/src/chatWrappers/LlamaChatWrapper.ts b/src/chatWrappers/LlamaChatWrapper.ts index f50416da..f40a631e 100644 --- a/src/chatWrappers/LlamaChatWrapper.ts +++ b/src/chatWrappers/LlamaChatWrapper.ts @@ -6,6 +6,20 @@ import {BuiltinSpecialToken, LlamaText, SpecialToken} from "../utils/LlamaText.j export class LlamaChatWrapper extends ChatWrapper { public readonly wrapperName: string = "LlamaChat"; + /** @internal */ private readonly _addSpaceBeforeEos: boolean; + + /** @internal */ + public constructor({ + _addSpaceBeforeEos = true + }: { + /** @internal */ + _addSpaceBeforeEos?: boolean + } = {}) { + super(); + + this._addSpaceBeforeEos = _addSpaceBeforeEos; + } + public override generateContextText(history: readonly ChatHistoryItem[], {availableFunctions, documentFunctionParams}: { availableFunctions?: ChatModelFunctions, documentFunctionParams?: boolean @@ -57,7 +71,8 @@ export class LlamaChatWrapper extends ChatWrapper { } else if (item.type === "model") { currentAggregateFocus = "model"; modelTexts.push(this.generateModelResponseText(item.response)); - } + } else + void (item satisfies never); } flush(); @@ -80,9 +95,12 @@ export class LlamaChatWrapper extends ChatWrapper { new SpecialToken("\n<>\n\n") ]), user, - new SpecialToken(" [/INST]\n\n") + new SpecialToken(" [/INST] ") ]), model, + this._addSpaceBeforeEos + ? " " + : "", isLastItem ? LlamaText([]) : new BuiltinSpecialToken("EOS") @@ -98,4 +116,11 @@ export class LlamaChatWrapper extends ChatWrapper { ] }; } + + /** @internal */ + public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { + return [{}, { + _addSpaceBeforeEos: false + }] satisfies Partial[0]>[]; + } } diff --git a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts new file mode 100644 index 00000000..0ffd805c --- /dev/null +++ b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts @@ -0,0 +1,458 @@ +import {Template} from "@huggingface/jinja"; +import {splitText} from "lifecycle-utils"; +import {ChatHistoryItem, ChatModelFunctions, ChatUserMessage} from "../../types.js"; +import {BuiltinSpecialToken, LlamaText, SpecialToken} from "../../utils/LlamaText.js"; +import {ChatWrapper, ChatWrapperSettings} from "../../ChatWrapper.js"; +import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate} from "./utils/chatHistoryFunctionCallMessageTemplate.js"; + +export type JinjaTemplateChatWrapperOptions = { + template: string, + modelRoleName?: string, + userRoleName?: string, + systemRoleName?: string, + convertUnsupportedSystemMessagesToUserMessages?: boolean | "auto" | ConvertMessageFormatOptions, + functionCallMessageTemplate?: ChatHistoryFunctionCallMessageTemplate, + joinAdjacentMessagesOfTheSameType?: boolean +}; + +type ConvertMessageFormatOptions = { + use?: "always" | "ifNeeded", + format: `${string}{{message}}${string}` +}; + +const defaultConvertUnsupportedSystemMessagesToUserMessagesFormat: ConvertMessageFormatOptions = { + format: "System: {{message}}" +}; + +/** + * A chat wrapper based on a Jinja template. + * Useful for using the original model's Jinja template as-is without any additional conversion work to chat with a model. + * + * If you want to create a new chat wrapper from scratch, using this chat wrapper is not recommended, and instead you better inherit + * from the `ChatWrapper` class and implement a custom chat wrapper of your own in TypeScript. + * + * For a simpler way to create a chat wrapper, see the `TemplateChatWrapper` class. + */ +export class JinjaTemplateChatWrapper extends ChatWrapper { + public readonly wrapperName = "JinjaTemplate"; + public override readonly settings: ChatWrapperSettings; + + public readonly template: string; + public readonly modelRoleName: string; + public readonly userRoleName: string; + public readonly systemRoleName: string; + public readonly convertUnsupportedSystemMessagesToUserMessages?: ConvertMessageFormatOptions; + public readonly joinAdjacentMessagesOfTheSameType: boolean; + + /** @internal */ private readonly _jinjaTemplate: Template; + + public constructor({ + template, + modelRoleName = "assistant", + userRoleName = "user", + systemRoleName = "system", + convertUnsupportedSystemMessagesToUserMessages = defaultConvertUnsupportedSystemMessagesToUserMessagesFormat, + functionCallMessageTemplate, + joinAdjacentMessagesOfTheSameType = true + }: JinjaTemplateChatWrapperOptions) { + super(); + + if (template == null) + throw new Error("template cannot be null"); + + this.template = template; + this.modelRoleName = modelRoleName; + this.userRoleName = userRoleName; + this.systemRoleName = systemRoleName; + this.convertUnsupportedSystemMessagesToUserMessages = + resolveConvertUnsupportedSystemMessagesToUserMessagesOption(convertUnsupportedSystemMessagesToUserMessages); + this.joinAdjacentMessagesOfTheSameType = joinAdjacentMessagesOfTheSameType; + + this.settings = { + ...super.settings, + functions: parseFunctionCallMessageTemplate(functionCallMessageTemplate) ?? ChatWrapper.defaultSetting.functions + }; + + if (this.convertUnsupportedSystemMessagesToUserMessages != null && !this.convertUnsupportedSystemMessagesToUserMessages.format.includes("{{message}}")) + throw new Error('convertUnsupportedSystemMessagesToUserMessages format must include "{{message}}"'); + + this._jinjaTemplate = new Template(this.template); + this._runSanityTest(); + } + + public override generateContextText(history: readonly ChatHistoryItem[], {availableFunctions, documentFunctionParams}: { + availableFunctions?: ChatModelFunctions, + documentFunctionParams?: boolean + } = {}): { + contextText: LlamaText, + stopGenerationTriggers: LlamaText[] + } { + const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(history, availableFunctions, { + documentParams: documentFunctionParams + }); + + if (this.convertUnsupportedSystemMessagesToUserMessages == null) { + return this._generateContextText(historyWithFunctions, { + convertSystemMessagesToUserMessagesFormat: undefined + }); + } else if (this.convertUnsupportedSystemMessagesToUserMessages.use === "always") { + return this._generateContextText(historyWithFunctions, { + convertSystemMessagesToUserMessagesFormat: this.convertUnsupportedSystemMessagesToUserMessages.format + }); + } + + try { + return this._generateContextText(historyWithFunctions, { + convertSystemMessagesToUserMessagesFormat: undefined + }); + } catch (error) { + return this._generateContextText(historyWithFunctions, { + convertSystemMessagesToUserMessagesFormat: this.convertUnsupportedSystemMessagesToUserMessages.format + }); + } + } + + /** @internal */ + private _generateContextText(history: readonly ChatHistoryItem[], { + convertSystemMessagesToUserMessagesFormat + }: { + convertSystemMessagesToUserMessagesFormat?: string + }): { + contextText: LlamaText, + stopGenerationTriggers: LlamaText[] + } { + const transformedHistory = convertSystemMessagesToUserMessagesFormat == null + ? history + : history.map((item) => { + if (item.type === "system") + return { + type: "user", + text: convertSystemMessagesToUserMessagesFormat.replaceAll("{{message}}", item.text) + } satisfies ChatUserMessage; + + return item; + }); + + const resultItems: Array<{ + role: "system" | "user" | "model", + content: string + }> = []; + + const currentTexts: string[] = []; + let currentAggregateFocus: "system" | "user" | "model" | null = null; + + function flush() { + if (currentTexts.length > 0 && currentAggregateFocus != null) + resultItems.push({role: currentAggregateFocus, content: currentTexts.join("\n\n")}); + + currentTexts.length = 0; + } + + for (const item of transformedHistory) { + if (item.type === "system") { + if (!this.joinAdjacentMessagesOfTheSameType || currentAggregateFocus !== "system") + flush(); + + currentAggregateFocus = "system"; + currentTexts.push(item.text); + } else if (item.type === "user") { + if (!this.joinAdjacentMessagesOfTheSameType || currentAggregateFocus !== "user") + flush(); + + currentAggregateFocus = "user"; + currentTexts.push(item.text); + } else if (item.type === "model") { + if (!this.joinAdjacentMessagesOfTheSameType || currentAggregateFocus !== "model") + flush(); + + currentAggregateFocus = "model"; + currentTexts.push(this.generateModelResponseText(item.response)); + } else + void (item satisfies never); + } + + const lastItemIsModelMessage = currentAggregateFocus === "model"; + flush(); + + const idsGenerator = new UniqueTemplateId( + this.template + this.modelRoleName + this.userRoleName + this.systemRoleName + + (convertSystemMessagesToUserMessagesFormat ?? "") + resultItems.map(({content}) => content).join("\n\n") + ); + + const jinjaItems: Array<{ + role: string, + content: string + }> = []; + const jinjaRoleMap = { + system: this.systemRoleName, + user: this.userRoleName, + model: this.modelRoleName + } as const; + const idToContent = new Map(); + const modelMessageIds = new Set(); + const messageIds = new Set(); + + for (const resultItem of resultItems) { + const id = idsGenerator.generateId(); + + messageIds.add(id); + idToContent.set(id, resultItem.content); + jinjaItems.push({ + role: jinjaRoleMap[resultItem.role], + content: id + }); + + if (resultItem.role === "model") + modelMessageIds.add(id); + } + + const bosTokenId = idsGenerator.generateId(); + const eosTokenId = idsGenerator.generateId(); + + idToContent.set(bosTokenId, new BuiltinSpecialToken("BOS")); + idToContent.set(eosTokenId, new BuiltinSpecialToken("EOS")); + + const renderJinjaText = () => { + try { + return this._jinjaTemplate.render({ + messages: jinjaItems, + "bos_token": bosTokenId, + "eos_token": eosTokenId + }); + } catch (err) { + return this._jinjaTemplate.render({ + messages: jinjaItems, + "bos_token": bosTokenId, + "eos_token": eosTokenId, + "add_generation_prompt": true + }); + } + }; + + const validateThatAllMessageIdsAreUsed = (parts: ReturnType>) => { + const messageIdsLeft = new Set(messageIds); + + for (const part of parts) { + if (typeof part === "string") + continue; + + messageIdsLeft.delete(part.separator); + } + + if (messageIdsLeft.size !== 0) + throw new Error("Some input messages are not present in the generated Jinja template output"); + }; + + const renderJinjaAndSplitIntoParts = () => { + const splitJinjaParts = splitText(renderJinjaText(), [...idToContent.keys()]); + + if (lastItemIsModelMessage) { + let lastModelResponseIndex = -1; + + for (let i = splitJinjaParts.length - 1; i >= 0; i--) { + const part = splitJinjaParts[i]; + + if (typeof part === "string") + continue; + + if (modelMessageIds.has(part.separator)) { + lastModelResponseIndex = i; + break; + } else if (messageIds.has(part.separator)) { + validateThatAllMessageIdsAreUsed(splitJinjaParts); + throw new Error("Last message was expected to be a model message, but it was not"); + } + } + + if (lastModelResponseIndex < 0) { + validateThatAllMessageIdsAreUsed(splitJinjaParts); + throw new Error("A model message was expected to be the last message, but it was not found"); + } + + return { + splitJinjaParts: splitJinjaParts.slice(0, lastModelResponseIndex + 1), + stopGenerationJinjaParts: splitJinjaParts.slice(lastModelResponseIndex + 1) + }; + } + + return { + splitJinjaParts, + stopGenerationJinjaParts: [] + }; + }; + + const {splitJinjaParts, stopGenerationJinjaParts} = renderJinjaAndSplitIntoParts(); + + const messageIdsLeftToProcess = new Set(messageIds); + const contextText = LlamaText( + splitJinjaParts.map((part) => { + if (typeof part === "string") + return new SpecialToken(part); // things that are not message content can be tokenized with special tokens + + const message = idToContent.get(part.separator); + + if (message == null) + throw new Error(`Message with id "${part.separator}" not found`); + + messageIdsLeftToProcess.delete(part.separator); + + return message; + }) + ); + + if (messageIdsLeftToProcess.size !== 0) + throw new Error("Some input messages are not present in the generated Jinja template output"); + + return { + contextText, + stopGenerationTriggers: [ + LlamaText(new BuiltinSpecialToken("EOS")), + ...( + stopGenerationJinjaParts.length === 0 + ? [] + : [ + LlamaText( + stopGenerationJinjaParts.map((part) => { + if (typeof part === "string") + return new SpecialToken(part); + + const message = idToContent.get(part.separator); + + if (message == null) + throw new Error(`Message with id "${part.separator}" not found`); + + return message; + }) + ) + ] + ) + ] + }; + } + + /** + * Validate that this Jinja template can be rendered + * @internal + */ + private _runSanityTest() { + try { + for (const chatHistory of chatHistoriesForSanityTest) { + this.generateContextText(chatHistory); + } + } catch (err) { + throw new Error("The provided Jinja template failed that sanity test: " + String(err)); + } + } +} + +class UniqueTemplateId { + public readonly antiText: string; + private readonly _ids = new Set(); + + public constructor(antiText: string) { + this.antiText = antiText; + } + + public generateId(): string { + let id: string; + + do { + id = "W" + (Math.random() + .toString(36) + .slice(2)) + "W"; + } while (this._ids.has(id) || this.antiText.includes(id)); + + this._ids.add(id); + + return id; + } + + public removeId(id: string) { + this._ids.delete(id); + } +} + +function resolveConvertUnsupportedSystemMessagesToUserMessagesOption( + convertUnsupportedSystemMessagesToUserMessages?: JinjaTemplateChatWrapperOptions["convertUnsupportedSystemMessagesToUserMessages"] +): ConvertMessageFormatOptions | undefined { + if (convertUnsupportedSystemMessagesToUserMessages === false) + return undefined; + + if (convertUnsupportedSystemMessagesToUserMessages === true) + return { + ...defaultConvertUnsupportedSystemMessagesToUserMessagesFormat, + use: "always" + }; + + if (convertUnsupportedSystemMessagesToUserMessages === "auto") + return { + ...defaultConvertUnsupportedSystemMessagesToUserMessagesFormat, + use: "ifNeeded" + }; + + if (typeof convertUnsupportedSystemMessagesToUserMessages === "object") + return { + ...convertUnsupportedSystemMessagesToUserMessages, + use: convertUnsupportedSystemMessagesToUserMessages.use ?? "ifNeeded" + }; + + return {...defaultConvertUnsupportedSystemMessagesToUserMessagesFormat, use: "ifNeeded"}; +} + +const chatHistoriesForSanityTest: ChatHistoryItem[][] = [ + [{ + type: "system", + text: "System message ~!@#$%^&*()\n*" + }, { + type: "user", + text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: [""] + }], + + [{ + type: "system", + text: "System message ~!@#$%^&*()\n*" + }, { + type: "user", + text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] + }], + + [{ + type: "system", + text: "System message ~!@#$%^&*()\n*" + }, { + type: "user", + text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] + }, { + type: "user", + text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: [""] + }], + + [{ + type: "system", + text: "System message ~!@#$%^&*()\n*" + }, { + type: "user", + text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] + }, { + type: "user", + text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: ["Result2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] + }] +]; diff --git a/src/TemplateChatWrapper.ts b/src/chatWrappers/generic/TemplateChatWrapper.ts similarity index 79% rename from src/TemplateChatWrapper.ts rename to src/chatWrappers/generic/TemplateChatWrapper.ts index 3f3ba438..16223642 100644 --- a/src/TemplateChatWrapper.ts +++ b/src/chatWrappers/generic/TemplateChatWrapper.ts @@ -1,7 +1,8 @@ -import {ChatHistoryItem, ChatModelFunctions} from "./types.js"; -import {BuiltinSpecialToken, LlamaText, LlamaTextValue, SpecialToken} from "./utils/LlamaText.js"; -import {ChatWrapper, ChatWrapperSettings} from "./ChatWrapper.js"; -import {parseTextTemplate} from "./utils/parseTextTemplate.js"; +import {ChatHistoryItem, ChatModelFunctions} from "../../types.js"; +import {BuiltinSpecialToken, LlamaText, LlamaTextValue, SpecialToken} from "../../utils/LlamaText.js"; +import {ChatWrapper, ChatWrapperSettings} from "../../ChatWrapper.js"; +import {parseTextTemplate} from "../../utils/parseTextTemplate.js"; +import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate} from "./utils/chatHistoryFunctionCallMessageTemplate.js"; export type TemplateChatWrapperOptions = { template: ChatTemplate, @@ -13,6 +14,9 @@ export type TemplateChatWrapperOptions = { joinAdjacentMessagesOfTheSameType?: boolean }; +type ChatTemplate = `${`${string}{{systemPrompt}}` | ""}${string}{{history}}${string}{{completion}}${string}`; +type ChatHistoryTemplate = `${string}{{roleName}}${string}{{message}}${string}`; + /** * A chat wrapper based on a simple template. * @@ -69,6 +73,9 @@ export class TemplateChatWrapper extends ChatWrapper { }: TemplateChatWrapperOptions) { super(); + if (template == null || historyTemplate == null || modelRoleName == null || userRoleName == null) + throw new Error("Template chat wrapper settings must have a template, historyTemplate, modelRoleName, and userRoleName."); + this.template = template; this.historyTemplate = historyTemplate; this.modelRoleName = modelRoleName; @@ -102,9 +109,9 @@ export class TemplateChatWrapper extends ChatWrapper { model: string }> = []; - let systemTexts: string[] = []; - let userTexts: string[] = []; - let modelTexts: string[] = []; + const systemTexts: string[] = []; + const userTexts: string[] = []; + const modelTexts: string[] = []; let currentAggregateFocus: "system" | "user" | "model" | null = null; function flush() { @@ -115,20 +122,11 @@ export class TemplateChatWrapper extends ChatWrapper { model: modelTexts.join("\n\n") }); - systemTexts = []; - userTexts = []; - modelTexts = []; + systemTexts.length = 0; + userTexts.length = 0; + modelTexts.length = 0; } - const getHistoryItem = (role: "system" | "user" | "model", text: string, prefix?: string | null) => { - const {roleNamePrefix, messagePrefix, messageSuffix} = this._parsedChatHistoryTemplate; - return LlamaText([ - new SpecialToken((prefix ?? "") + roleNamePrefix + role + messagePrefix), - text, - new SpecialToken(messageSuffix) - ]); - }; - for (const item of historyWithFunctions) { if (item.type === "system") { if (!this.joinAdjacentMessagesOfTheSameType || currentAggregateFocus !== "system") @@ -148,11 +146,21 @@ export class TemplateChatWrapper extends ChatWrapper { currentAggregateFocus = "model"; modelTexts.push(this.generateModelResponseText(item.response)); - } + } else + void (item satisfies never); } flush(); + const getHistoryItem = (role: "system" | "user" | "model", text: string, prefix?: string | null) => { + const {roleNamePrefix, messagePrefix, messageSuffix} = this._parsedChatHistoryTemplate; + return LlamaText([ + new SpecialToken((prefix ?? "") + roleNamePrefix + role + messagePrefix), + text, + new SpecialToken(messageSuffix) + ]); + }; + const contextText = LlamaText( resultItems.map(({system, user, model}, index) => { const isFirstItem = index === 0; @@ -218,68 +226,6 @@ export class TemplateChatWrapper extends ChatWrapper { } } -type ChatTemplate = `${`${string}{{systemPrompt}}` | ""}${string}{{history}}${string}{{completion}}${string}`; -type ChatHistoryTemplate = `${string}{{roleName}}${string}{{message}}${string}`; - -type ChatHistoryFunctionCallMessageTemplate = [ - call: `${string}{{functionName}}${string}{{functionParams}}${string}`, - result: `${string}{{functionCallResult}}${string}` -]; - -function parseFunctionCallMessageTemplate(template?: ChatHistoryFunctionCallMessageTemplate) { - if (template == null) - return null; - - const [functionCallTemplate, functionCallResultTemplate] = template; - - if (functionCallTemplate == null || functionCallResultTemplate == null) - throw new Error("Both function call and function call result templates are required"); - - const parsedFunctionCallTemplate = parseTextTemplate(functionCallTemplate, [{ - text: "{{functionName}}", - key: "functionName" - }, { - text: "{{functionParams}}", - key: "functionParams" - }]); - const parsedFunctionCallResultTemplate = parseTextTemplate(functionCallResultTemplate, [{ - text: "{{functionCallResult}}", - key: "functionCallResult" - }]); - - const callPrefix = parsedFunctionCallTemplate.functionName.prefix; - const callParamsPrefix = parsedFunctionCallTemplate.functionParams.prefix; - const callSuffix = parsedFunctionCallTemplate.functionParams.suffix; - - const resultPrefix = parsedFunctionCallResultTemplate.functionCallResult.prefix; - const resultSuffix = parsedFunctionCallResultTemplate.functionCallResult.suffix; - - if (callPrefix.length === 0) - throw new Error('Function call template must have text before "{{functionName}}"'); - - if (callSuffix.length === 0) - throw new Error('Function call template must have text after "{{functionParams}}"'); - - if (resultPrefix.length === 0) - throw new Error('Function call result template must have text before "{{functionCallResult}}"'); - - if (resultSuffix.length === 0) - throw new Error('Function call result template must have text after "{{functionCallResult}}"'); - - return { - call: { - optionalPrefixSpace: true, - prefix: callPrefix, - paramsPrefix: callParamsPrefix, - suffix: callSuffix - }, - result: { - prefix: resultPrefix, - suffix: resultSuffix - } - }; -} - function parseChatTemplate(template: ChatTemplate): { systemPromptPrefix: string | null, historyPrefix: string, diff --git a/src/chatWrappers/generic/utils/chatHistoryFunctionCallMessageTemplate.ts b/src/chatWrappers/generic/utils/chatHistoryFunctionCallMessageTemplate.ts new file mode 100644 index 00000000..50fbb08e --- /dev/null +++ b/src/chatWrappers/generic/utils/chatHistoryFunctionCallMessageTemplate.ts @@ -0,0 +1,77 @@ +import {parseTextTemplate} from "../../../utils/parseTextTemplate.js"; + +export function parseFunctionCallMessageTemplate(template?: ChatHistoryFunctionCallMessageTemplate) { + if (template == null) + return null; + + const [functionCallTemplate, functionCallResultTemplate] = template; + + if (functionCallTemplate == null || functionCallResultTemplate == null) + throw new Error("Both function call and function call result templates are required"); + + const parsedFunctionCallTemplate = parseTextTemplate(functionCallTemplate, [{ + text: "{{functionName}}", + key: "functionName" + }, { + text: "{{functionParams}}", + key: "functionParams" + }]); + const parsedFunctionCallResultTemplate = parseTextTemplate(functionCallResultTemplate, [{ + text: "{{functionCallResult}}", + key: "functionCallResult" + }]); + + const callPrefix = parsedFunctionCallTemplate.functionName.prefix; + const callParamsPrefix = parsedFunctionCallTemplate.functionParams.prefix; + const callSuffix = parsedFunctionCallTemplate.functionParams.suffix; + + const resultPrefix = parsedFunctionCallResultTemplate.functionCallResult.prefix; + const resultSuffix = parsedFunctionCallResultTemplate.functionCallResult.suffix; + + if (callPrefix.length === 0) + throw new Error("Function call template must have text before \"{{functionName}}\""); + + if (callSuffix.length === 0) + throw new Error("Function call template must have text after \"{{functionParams}}\""); + + if (resultPrefix.length === 0) + throw new Error("Function call result template must have text before \"{{functionCallResult}}\""); + + if (resultSuffix.length === 0) + throw new Error("Function call result template must have text after \"{{functionCallResult}}\""); + + return { + call: { + optionalPrefixSpace: true, + prefix: callPrefix, + paramsPrefix: callParamsPrefix, + suffix: callSuffix + }, + result: { + prefix: resultPrefix, + suffix: resultSuffix + } + }; +} + +/** + * Template format for how functions can be called by the model and how their results are fed to the model after the function call. + * Consists of an array with two elements: + * 1. The function call template. + * 2. The function call result template. + * + * For example: + * ```typescript + * const template: ChatHistoryFunctionCallMessageTemplate = [ + * "[[call: {{functionName}}({{functionParams}})]]" + * " [[result: {{functionCallResult}}]]" + * ]; + * ``` + * + * It's mandatory for the call template to have text before `{{functionName}}` in order for the chat wrapper know when + * to activate the function calling grammar. + */ +export type ChatHistoryFunctionCallMessageTemplate = [ + call: `${string}{{functionName}}${string}{{functionParams}}${string}`, + result: `${string}{{functionCallResult}}${string}` +]; diff --git a/src/chatWrappers/resolveChatWrapperBasedOnModel.ts b/src/chatWrappers/resolveChatWrapperBasedOnModel.ts deleted file mode 100644 index 7aeb9773..00000000 --- a/src/chatWrappers/resolveChatWrapperBasedOnModel.ts +++ /dev/null @@ -1,79 +0,0 @@ -import {parseModelFileName} from "../utils/parseModelFileName.js"; -import {LlamaChatWrapper} from "./LlamaChatWrapper.js"; -import {ChatMLChatWrapper} from "./ChatMLChatWrapper.js"; -import {GeneralChatWrapper} from "./GeneralChatWrapper.js"; -import {FalconChatWrapper} from "./FalconChatWrapper.js"; -import {FunctionaryChatWrapper} from "./FunctionaryChatWrapper.js"; -import {AlpacaChatWrapper} from "./AlpacaChatWrapper.js"; -import {GemmaChatWrapper} from "./GemmaChatWrapper.js"; -import type {GgufFileInfo} from "../gguf/types/GgufFileInfoTypes.js"; - - -/** - * @param options - */ -export function resolveChatWrapperBasedOnModel({ - bosString, - filename, - fileInfo -}: { - bosString?: string | null, - filename?: string, - fileInfo?: GgufFileInfo -}) { - if (filename != null) { - const {name, subType, fileType} = parseModelFileName(filename); - - if (fileType?.toLowerCase() === "gguf") { - const lowercaseName = name?.toLowerCase(); - const lowercaseSubType = subType?.toLowerCase(); - const splitLowercaseSubType = lowercaseSubType?.split("-") ?? []; - const firstSplitLowercaseSubType = splitLowercaseSubType[0]; - - if (lowercaseName === "llama") { - if (splitLowercaseSubType.includes("chat")) - return LlamaChatWrapper; - - return GeneralChatWrapper; - } else if (lowercaseName === "yarn" && firstSplitLowercaseSubType === "llama") - return LlamaChatWrapper; - else if (lowercaseName === "orca") - return ChatMLChatWrapper; - else if (lowercaseName === "phind" && lowercaseSubType === "codellama") - return LlamaChatWrapper; - else if (lowercaseName === "mistral") - return GeneralChatWrapper; - else if (firstSplitLowercaseSubType === "llama") - return LlamaChatWrapper; - else if (lowercaseSubType === "alpaca") - return AlpacaChatWrapper; - else if (lowercaseName === "functionary") - return FunctionaryChatWrapper; - else if (lowercaseName === "dolphin" && splitLowercaseSubType.includes("mistral")) - return ChatMLChatWrapper; - else if (lowercaseName === "gemma") - return GemmaChatWrapper; - } - } - - if (fileInfo != null) { - const arch = fileInfo.metadata.general?.architecture; - - if (arch === "llama") - return LlamaChatWrapper; - else if (arch === "falcon") - return FalconChatWrapper; - } - - if (bosString === "" || bosString == null) - return null; - - if ("[INST] <>\n".startsWith(bosString)) { - return LlamaChatWrapper; - } else if ("<|im_start|>system\n".startsWith(bosString)) { - return ChatMLChatWrapper; - } - - return null; -} - diff --git a/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts b/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts new file mode 100644 index 00000000..6e02829a --- /dev/null +++ b/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts @@ -0,0 +1,249 @@ +import {ChatWrapper} from "../../ChatWrapper.js"; +import {ChatHistoryItem, ChatModelResponse, ChatUserMessage, Tokenizer} from "../../types.js"; +import {JinjaTemplateChatWrapper, JinjaTemplateChatWrapperOptions} from "../generic/JinjaTemplateChatWrapper.js"; +import {BuiltinSpecialToken, LlamaText} from "../../utils/LlamaText.js"; +import {compareTokens} from "../../utils/compareTokens.js"; +import {StopGenerationDetector} from "../../utils/StopGenerationDetector.js"; + +export function isJinjaTemplateEquivalentToSpecializedChatWrapper( + jinjaTemplateWrapperOptions: JinjaTemplateChatWrapperOptions, + specializedChatWrapper: ChatWrapper, + tokenizer?: Tokenizer +): boolean { + const canTestMultipleConvertSystemMessagesToUserMessages = + jinjaTemplateWrapperOptions.convertUnsupportedSystemMessagesToUserMessages == null || + jinjaTemplateWrapperOptions.convertUnsupportedSystemMessagesToUserMessages === "auto"; + + try { + const jinjaChatWrapper = new JinjaTemplateChatWrapper({ + ...jinjaTemplateWrapperOptions, + convertUnsupportedSystemMessagesToUserMessages: canTestMultipleConvertSystemMessagesToUserMessages + ? false + : jinjaTemplateWrapperOptions.convertUnsupportedSystemMessagesToUserMessages + }); + + if (checkEquivalence(jinjaChatWrapper, specializedChatWrapper, testChatHistories, tokenizer)) + return true; + } catch (err) { + // Do nothing + } + + if (!canTestMultipleConvertSystemMessagesToUserMessages) + return false; + + try { + const convertSystemMessagesToUserMessagesTemplate = "System: {{message}}"; + const jinjaChatWrapper = new JinjaTemplateChatWrapper({ + ...jinjaTemplateWrapperOptions, + convertUnsupportedSystemMessagesToUserMessages: { + use: "always", + format: convertSystemMessagesToUserMessagesTemplate + } + }); + + const transformedTestChatHistories = testChatHistories + .map((history) => ( + history + .slice() + .map((item, index, array) => { + if (item.type === "system") { + if (index === 0 && array.length > 1 && array[1].type === "user") { + array[1] = { + type: "user", + text: convertSystemMessagesToUserMessagesTemplate.replace("{{message}}", item.text) + "\n\n" + array[1].text + } satisfies ChatHistoryItem; + return null; + } + + return { + type: "user", + text: convertSystemMessagesToUserMessagesTemplate.replace("{{message}}", item.text) + } satisfies ChatHistoryItem; + } + + return item; + }) + .filter((item): item is ChatUserMessage | ChatModelResponse => item != null) + )); + + if (checkEquivalence(jinjaChatWrapper, specializedChatWrapper, transformedTestChatHistories, tokenizer)) + return true; + } catch (err) { + // Do nothing + } + + return false; +} + +function checkEquivalence( + jinjaChatWrapper: JinjaTemplateChatWrapper, + specializedChatWrapper: ChatWrapper, + testChatHistories: ChatHistoryItem[][], + tokenizer?: Tokenizer +): boolean { + for (const testChatHistory of testChatHistories) { + const jinjaRes = jinjaChatWrapper.generateContextText(testChatHistory); + const specializedWrapperRes = specializedChatWrapper.generateContextText(testChatHistory); + + if (!compareContextTexts(jinjaRes.contextText, specializedWrapperRes.contextText, tokenizer)) + return false; + + const jinjaHasAllSpecializedStopGenerationTriggers = jinjaRes.stopGenerationTriggers + .every((trigger) => { + return [trigger, trigger.trimEnd(), trigger.trimStart(), trigger.trimStart().trimEnd()].some((normalizedJinjaTrigger) => { + if (normalizedJinjaTrigger.values.length === 0) + return true; + + const foundSimilarTriggers = specializedWrapperRes.stopGenerationTriggers.some((specializedTrigger) => ( + normalizedJinjaTrigger.includes(specializedTrigger) + )); + + if (foundSimilarTriggers) + return true; + + if (tokenizer != null) { + const resolvedStopGenerationTrigger = StopGenerationDetector.resolveLlamaTextTrigger( + normalizedJinjaTrigger, + tokenizer + ); + + const foundSimilarOrShorterTokenizedTriggers = specializedWrapperRes.stopGenerationTriggers + .some((specializedTrigger) => { + const resolvedSpecializedTrigger = StopGenerationDetector.resolveLlamaTextTrigger( + specializedTrigger, + tokenizer + ); + + return resolvedSpecializedTrigger.every((item, index) => { + const resolveTriggerItem = resolvedStopGenerationTrigger[index]; + + if (typeof item === "string" && typeof resolveTriggerItem === "string") + return item === resolveTriggerItem; + else if (typeof item === "string" || typeof resolveTriggerItem === "string") + return false; + + return compareTokens(item, resolveTriggerItem); + }); + }); + + if (foundSimilarOrShorterTokenizedTriggers) + return true; + } + + return false; + }); + }); + + if (!jinjaHasAllSpecializedStopGenerationTriggers) + return false; + } + + return true; +} + +function compareContextTexts(text1: LlamaText, text2: LlamaText, tokenizer?: Tokenizer) { + function compare(text1: LlamaText, text2: LlamaText) { + if (LlamaText.compare(text1, text2)) + return true; + + if (tokenizer != null) { + const tokenizedText1 = text1.tokenize(tokenizer); + const tokenizedText2 = text2.tokenize(tokenizer); + + if (tokenizedText1.length === tokenizedText2.length) + return tokenizedText1.every((token, index) => compareTokens(token, tokenizedText2[index])); + } + + return false; + } + + const trimmedText1 = text1.trimEnd(); + const trimmedText2 = text2.trimEnd(); + + const normalizedText1 = removeLeadingBos(trimmedText1); + const normalizedText2 = removeLeadingBos(trimmedText2); + + const texts1 = (normalizedText1.length !== trimmedText1.length && tokenizer != null) + ? [trimmedText1, normalizedText1] + : [normalizedText1]; + + const texts2 = (normalizedText2.length !== trimmedText2.length && tokenizer != null) + ? [trimmedText2, normalizedText2] + : [normalizedText2]; + + return texts1.some((text1) => ( + texts2.some((text2) => ( + compare(text1, text2) + )) + )); +} + +const testChatHistories: ChatHistoryItem[][] = [ + [{ + type: "system", + text: "System message ~!@#$%^&*()\n*" + }, { + type: "user", + text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: [""] + }], + + [{ + type: "system", + text: "System message ~!@#$%^&*()\n*" + }, { + type: "user", + text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] + }], + + [{ + type: "system", + text: "System message ~!@#$%^&*()\n*" + }, { + type: "user", + text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] + }, { + type: "user", + text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: [""] + }], + + [{ + type: "system", + text: "System message ~!@#$%^&*()\n*" + }, { + type: "user", + text: "Message 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: ["Result 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] + }, { + type: "user", + text: "Message2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~" + }, { + type: "model", + response: ["Result2 1234567890!@#$%^&*()_+-=[]{}|\\:;\"',./<>?`~"] + }] +]; + +function removeLeadingBos(llamaText: LlamaText) { + if (llamaText.values.length === 0) + return llamaText; + + const firstValue = llamaText.values[0]; + + if (firstValue instanceof BuiltinSpecialToken && firstValue.value === "BOS") + return LlamaText(llamaText.values.slice(1)); + + return llamaText; +} diff --git a/src/chatWrappers/utils/resolveChatWrapper.ts b/src/chatWrappers/utils/resolveChatWrapper.ts new file mode 100644 index 00000000..6502f648 --- /dev/null +++ b/src/chatWrappers/utils/resolveChatWrapper.ts @@ -0,0 +1,249 @@ +import {parseModelFileName} from "../../utils/parseModelFileName.js"; +import {LlamaChatWrapper} from "../LlamaChatWrapper.js"; +import {ChatMLChatWrapper} from "../ChatMLChatWrapper.js"; +import {GeneralChatWrapper} from "../GeneralChatWrapper.js"; +import {FalconChatWrapper} from "../FalconChatWrapper.js"; +import {FunctionaryChatWrapper} from "../FunctionaryChatWrapper.js"; +import {AlpacaChatWrapper} from "../AlpacaChatWrapper.js"; +import {GemmaChatWrapper} from "../GemmaChatWrapper.js"; +import {JinjaTemplateChatWrapper, JinjaTemplateChatWrapperOptions} from "../generic/JinjaTemplateChatWrapper.js"; +import {TemplateChatWrapper} from "../generic/TemplateChatWrapper.js"; +import {getConsoleLogPrefix} from "../../utils/getConsoleLogPrefix.js"; +import {Tokenizer} from "../../types.js"; +import {isJinjaTemplateEquivalentToSpecializedChatWrapper} from "./isJinjaTemplateEquivalentToSpecializedChatWrapper.js"; +import type {GgufFileInfo} from "../../gguf/types/GgufFileInfoTypes.js"; + + +export const specializedChatWrapperTypeNames = Object.freeze([ + "general", "llamaChat", "alpacaChat", "functionary", "chatML", "falconChat", "gemma" +] as const); +export type SpecializedChatWrapperTypeName = (typeof specializedChatWrapperTypeNames)[number]; + +export const templateChatWrapperTypeNames = Object.freeze([ + "template", "jinjaTemplate" +] as const); +export type TemplateChatWrapperTypeName = (typeof templateChatWrapperTypeNames)[number]; + +export const resolvableChatWrapperTypeNames = Object.freeze([ + "auto", + ...specializedChatWrapperTypeNames, + ...templateChatWrapperTypeNames +] as const); +export type ResolvableChatWrapperTypeName = (typeof resolvableChatWrapperTypeNames)[number]; + +const chatWrappers = { + "general": GeneralChatWrapper, + "llamaChat": LlamaChatWrapper, + "alpacaChat": AlpacaChatWrapper, + "functionary": FunctionaryChatWrapper, + "chatML": ChatMLChatWrapper, + "falconChat": FalconChatWrapper, + "gemma": GemmaChatWrapper, + "template": TemplateChatWrapper, + "jinjaTemplate": JinjaTemplateChatWrapper +} as const satisfies Record; +const chatWrapperToConfigType = new Map( + Object.entries(chatWrappers) + .map(([configType, Wrapper]) => ( + [Wrapper, configType as keyof typeof chatWrappers] + )) +); + +export type ResolveChatWrapperOptions = { + /** + * Resolve to a specific chat wrapper type. + * You better not set this option unless you need to force a specific chat wrapper type. + * + * Defaults to `"auto"`. + */ + type?: "auto" | SpecializedChatWrapperTypeName | TemplateChatWrapperTypeName, + + bosString?: string | null, + filename?: string, + fileInfo?: GgufFileInfo, + tokenizer?: Tokenizer, + customWrapperSettings?: { + [wrapper in keyof typeof chatWrappers]?: ConstructorParameters<(typeof chatWrappers)[wrapper]>[0] + }, + warningLogs?: boolean, + fallbackToOtherWrappersOnJinjaError?: boolean +}; + +/** + * Resolve to a chat wrapper instance based on the provided information. + * The more information provided, the better the resolution will be (except for `type`). + * + * It's recommended to not set `type` to a specific chat wrapper in order for the resolution to be more flexible, but it is useful for when + * you need to provide the ability to force a specific chat wrapper type. + * Note that when setting `type` to a generic chat wrapper type (such as `"template"` or `"jinjaTemplate"`), the `customWrapperSettings` + * must contain the necessary settings for that chat wrapper to be created. + * + * When loading a Jinja chat template from either `fileInfo` or `customWrapperSettings.jinjaTemplate.template`, + * if the chat template format is invalid, it fallbacks to resolve other chat wrappers, + * unless `fallbackToOtherWrappersOnJinjaError` is set to `false` (in which case, it will throw an error). + */ +export function resolveChatWrapper({ + type = "auto", + bosString, + filename, + fileInfo, + tokenizer, + customWrapperSettings, + warningLogs = true, + fallbackToOtherWrappersOnJinjaError = true +}: ResolveChatWrapperOptions) { + function createSpecializedChatWrapper(specializedChatWrapper: typeof chatWrappers[SpecializedChatWrapperTypeName]) { + const chatWrapperConfigType = chatWrapperToConfigType.get(specializedChatWrapper) as SpecializedChatWrapperTypeName; + const chatWrapperSettings = customWrapperSettings?.[chatWrapperConfigType]; + + return new (specializedChatWrapper as any)(chatWrapperSettings); + } + + if (type !== "auto" && type != null) { + if (isTemplateChatWrapperType(type)) { + const Wrapper = chatWrappers[type]; + + if (isClassReference(Wrapper, TemplateChatWrapper)) { + const wrapperSettings = customWrapperSettings?.template; + if (wrapperSettings == null || wrapperSettings?.template == null || wrapperSettings?.historyTemplate == null || + wrapperSettings?.modelRoleName == null || wrapperSettings?.userRoleName == null + ) { + if (warningLogs) + console.warn(getConsoleLogPrefix() + "Template chat wrapper settings must have a template, historyTemplate, modelRoleName, and userRoleName. Falling back to resolve other chat wrapper types."); + } else + return new TemplateChatWrapper(wrapperSettings); + } else if (isClassReference(Wrapper, JinjaTemplateChatWrapper)) { + const jinjaTemplate = customWrapperSettings?.jinjaTemplate?.template ?? fileInfo?.metadata?.tokenizer?.chat_template; + + if (jinjaTemplate == null) { + if (warningLogs) + console.warn(getConsoleLogPrefix() + "Jinja template chat wrapper received no template. Falling back to resolve other chat wrapper types."); + } else { + try { + return new JinjaTemplateChatWrapper({ + ...(customWrapperSettings?.jinjaTemplate ?? {}), + template: jinjaTemplate + }); + } catch (err) { + if (!fallbackToOtherWrappersOnJinjaError) + throw err; + else if (warningLogs) + console.error(getConsoleLogPrefix() + "Error creating Jinja template chat wrapper. Falling back to resolve other chat wrappers. Error:", err); + } + } + } else + void (Wrapper satisfies never); + } else if (Object.hasOwn(chatWrappers, type)) { + const Wrapper = chatWrappers[type]; + const wrapperSettings: ConstructorParameters[0] | undefined = + customWrapperSettings?.[type]; + + return new (Wrapper as any)(wrapperSettings); + } + } + + const modelJinjaTemplate = customWrapperSettings?.jinjaTemplate?.template ?? fileInfo?.metadata?.tokenizer?.chat_template; + + if (modelJinjaTemplate != null && modelJinjaTemplate.trim() !== "") { + const jinjaTemplateChatWrapperOptions: JinjaTemplateChatWrapperOptions = { + ...(customWrapperSettings?.jinjaTemplate ?? {}), + template: modelJinjaTemplate + }; + + for (const specializedChatWrapperTypeName of specializedChatWrapperTypeNames) { + const Wrapper = chatWrappers[specializedChatWrapperTypeName]; + const wrapperSettings = customWrapperSettings?.[specializedChatWrapperTypeName]; + + const testOptionConfigurations = Wrapper._getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate?.() ?? []; + if (testOptionConfigurations.length === 0) + testOptionConfigurations.push({} as any); + + for (const testConfiguration of testOptionConfigurations) { + const chatWrapper = new (Wrapper as any)({ + ...(wrapperSettings ?? {}), + ...(testConfiguration ?? {}) + }); + + if (isJinjaTemplateEquivalentToSpecializedChatWrapper(jinjaTemplateChatWrapperOptions, chatWrapper, tokenizer)) + return new (Wrapper as any)(wrapperSettings ?? {}); + } + } + + if (!fallbackToOtherWrappersOnJinjaError) + return new JinjaTemplateChatWrapper(jinjaTemplateChatWrapperOptions); + + try { + return new JinjaTemplateChatWrapper(jinjaTemplateChatWrapperOptions); + } catch (err) { + console.error(getConsoleLogPrefix() + "Error creating Jinja template chat wrapper. Falling back to resolve other chat wrappers. Error:", err); + } + } + + if (filename != null) { + const {name, subType, fileType} = parseModelFileName(filename); + + if (fileType?.toLowerCase() === "gguf") { + const lowercaseName = name?.toLowerCase(); + const lowercaseSubType = subType?.toLowerCase(); + const splitLowercaseSubType = lowercaseSubType?.split("-") ?? []; + const firstSplitLowercaseSubType = splitLowercaseSubType[0]; + + if (lowercaseName === "llama") { + if (splitLowercaseSubType.includes("chat")) + return createSpecializedChatWrapper(LlamaChatWrapper); + + return createSpecializedChatWrapper(GeneralChatWrapper); + } else if (lowercaseName === "yarn" && firstSplitLowercaseSubType === "llama") + return createSpecializedChatWrapper(LlamaChatWrapper); + else if (lowercaseName === "orca") + return createSpecializedChatWrapper(ChatMLChatWrapper); + else if (lowercaseName === "phind" && lowercaseSubType === "codellama") + return createSpecializedChatWrapper(LlamaChatWrapper); + else if (lowercaseName === "mistral") + return createSpecializedChatWrapper(GeneralChatWrapper); + else if (firstSplitLowercaseSubType === "llama") + return createSpecializedChatWrapper(LlamaChatWrapper); + else if (lowercaseSubType === "alpaca") + return createSpecializedChatWrapper(AlpacaChatWrapper); + else if (lowercaseName === "functionary") + return createSpecializedChatWrapper(FunctionaryChatWrapper); + else if (lowercaseName === "dolphin" && splitLowercaseSubType.includes("mistral")) + return createSpecializedChatWrapper(ChatMLChatWrapper); + else if (lowercaseName === "gemma") + return createSpecializedChatWrapper(GemmaChatWrapper); + } + } + + if (fileInfo != null) { + const arch = fileInfo.metadata.general?.architecture; + + if (arch === "llama") + return createSpecializedChatWrapper(LlamaChatWrapper); + else if (arch === "falcon") + return createSpecializedChatWrapper(FalconChatWrapper); + } + + if (bosString === "" || bosString == null) + return null; + + if ("[INST] <>\n".startsWith(bosString)) { + return createSpecializedChatWrapper(LlamaChatWrapper); + } else if ("<|im_start|>system\n".startsWith(bosString)) { + return createSpecializedChatWrapper(ChatMLChatWrapper); + } + + return null; +} + +export function isSpecializedChatWrapperType(type: string): type is SpecializedChatWrapperTypeName { + return specializedChatWrapperTypeNames.includes(type as any); +} + +export function isTemplateChatWrapperType(type: string): type is TemplateChatWrapperTypeName { + return templateChatWrapperTypeNames.includes(type as any); +} + +// this is needed because TypeScript guards don't work automatically with class references +function isClassReference(value: any, classReference: T): value is T { + return value === classReference; +} diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts index c8a6e196..5a789f6b 100644 --- a/src/cli/commands/ChatCommand.ts +++ b/src/cli/commands/ChatCommand.ts @@ -14,13 +14,14 @@ import {LlamaGrammar} from "../../evaluator/LlamaGrammar.js"; import {LlamaChatSession} from "../../evaluator/LlamaChatSession/LlamaChatSession.js"; import {LlamaJsonSchemaGrammar} from "../../evaluator/LlamaJsonSchemaGrammar.js"; import {LlamaLogLevel, LlamaLogLevelGreaterThan} from "../../bindings/types.js"; -import { - ChatWrapperTypeName, chatWrapperTypeNames, resolveChatWrapperBasedOnWrapperTypeName -} from "../../bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.js"; import withOra from "../../utils/withOra.js"; import {TokenMeter} from "../../evaluator/TokenMeter.js"; import {printInfoLine} from "../utils/printInfoLine.js"; import {getPrettyBuildGpuName} from "../../bindings/consts.js"; +import { + resolveChatWrapper, SpecializedChatWrapperTypeName, specializedChatWrapperTypeNames +} from "../../chatWrappers/utils/resolveChatWrapper.js"; +import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js"; type ChatCommand = { model: string, @@ -29,7 +30,7 @@ type ChatCommand = { systemPromptFile?: string, prompt?: string, promptFile?: string, - wrapper: ChatWrapperTypeName, + wrapper: SpecializedChatWrapperTypeName | "auto", contextSize?: number, batchSize?: number, grammar: "text" | Parameters[1], @@ -103,7 +104,7 @@ export const ChatCommand: CommandModule = { alias: "w", type: "string", default: "auto" as ChatCommand["wrapper"], - choices: chatWrapperTypeNames, + choices: ["auto", ...specializedChatWrapperTypeNames] as const, description: "Chat wrapper to use. Use `auto` to automatically select a wrapper based on the model's BOS token", group: "Optional:" }) @@ -363,11 +364,13 @@ async function RunChat({ : undefined; const bos = model.tokens.bosString; // bos = beginning of sequence const eos = model.tokens.bosString; // eos = end of sequence - const chatWrapper = resolveChatWrapperBasedOnWrapperTypeName(wrapper, { + const chatWrapper = resolveChatWrapper({ + type: wrapper, bosString: bos, filename: model.filename, - fileInfo: model.fileInfo - }); + fileInfo: model.fileInfo, + tokenizer: model.tokenize + }) ?? new GeneralChatWrapper(); const contextSequence = context.getSequence(); const session = new LlamaChatSession({ contextSequence, diff --git a/src/cli/commands/inspect/commands/InspectGgufCommand.ts b/src/cli/commands/inspect/commands/InspectGgufCommand.ts index dd2362b0..e003c800 100644 --- a/src/cli/commands/inspect/commands/InspectGgufCommand.ts +++ b/src/cli/commands/inspect/commands/InspectGgufCommand.ts @@ -7,6 +7,7 @@ import fs from "fs-extra"; import {readGgufFileInfo} from "../../../../gguf/readGgufFileInfo.js"; import {prettyPrintObject, PrettyPrintObjectOptions} from "../../../../utils/prettyPrintObject.js"; import {getGgufFileTypeName} from "../../../../gguf/utils/getGgufFileTypeName.js"; +import {normalizeGgufDownloadUrl} from "../../../../gguf/utils/normalizeGgufDownloadUrl.js"; type InspectGgufCommand = { path: string, @@ -55,7 +56,7 @@ export const InspectGgufCommand: CommandModule = { async handler({path: ggufPath, fullTensorInfo, fullMetadataArrays, plainJson, outputToJsonFile}: InspectGgufCommand) { const isPathUrl = ggufPath.startsWith("http://") || ggufPath.startsWith("https://"); const resolvedGgufPath = isPathUrl - ? ggufPath + ? normalizeGgufDownloadUrl(ggufPath) : path.resolve(ggufPath); if (!plainJson) { diff --git a/src/evaluator/LlamaChat/LlamaChat.ts b/src/evaluator/LlamaChat/LlamaChat.ts index ce33c25b..a6c6c9fe 100644 --- a/src/evaluator/LlamaChat/LlamaChat.ts +++ b/src/evaluator/LlamaChat/LlamaChat.ts @@ -1,6 +1,5 @@ import {DisposeAggregator, DisposedError, EventRelay} from "lifecycle-utils"; import {ChatWrapper} from "../../ChatWrapper.js"; -import {resolveChatWrapper} from "../../utils/resolveChatWrapper.js"; import {LlamaContextSequence} from "../LlamaContext/LlamaContext.js"; import {ChatHistoryItem, ChatModelFunctions, ChatModelResponse, LLamaContextualRepeatPenalty, Token, Tokenizer} from "../../types.js"; import {GbnfJsonSchemaToType} from "../../utils/gbnfJson/types.js"; @@ -13,6 +12,8 @@ import {QueuedTokenReleaseLock, TokenStreamRegulator} from "../../utils/TokenStr import {EvaluationPriority} from "../LlamaContext/types.js"; import {UNKNOWN_UNICODE_CHAR} from "../../consts.js"; import {getQueuedTokensBeforeStopTrigger} from "../../utils/getQueuedTokensBeforeStopTrigger.js"; +import {resolveChatWrapper} from "../../chatWrappers/utils/resolveChatWrapper.js"; +import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js"; import { eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy } from "./utils/contextShiftStrategies/eraseFirstResponseAndKeepFirstSystemChatContextShiftStrategy.js"; @@ -180,7 +181,16 @@ export class LlamaChat { ); this._disposeAggregator.add(this.onDispose.dispatchEvent); - this._chatWrapper = resolveChatWrapper(chatWrapper, contextSequence.model); + this._chatWrapper = chatWrapper === "auto" + ? ( + resolveChatWrapper({ + bosString: contextSequence.model.tokens.bosString, + filename: contextSequence.model.filename, + fileInfo: contextSequence.model.fileInfo, + tokenizer: contextSequence.model.tokenize + }) ?? new GeneralChatWrapper() + ) + : chatWrapper; } public dispose({disposeSequence = this._autoDisposeSequence}: {disposeSequence?: boolean} = {}) { diff --git a/src/evaluator/LlamaModel.ts b/src/evaluator/LlamaModel.ts index dd93b116..ff44ea6f 100644 --- a/src/evaluator/LlamaModel.ts +++ b/src/evaluator/LlamaModel.ts @@ -3,15 +3,17 @@ import path from "path"; import {AsyncDisposeAggregator, DisposedError, EventRelay, withLock} from "lifecycle-utils"; import {removeNullFields} from "../utils/removeNullFields.js"; import {Token} from "../types.js"; -import {ModelTypeDescription, AddonModel} from "../bindings/AddonTypes.js"; +import {AddonModel, ModelTypeDescription} from "../bindings/AddonTypes.js"; import {DisposalPreventionHandle, DisposeGuard} from "../utils/DisposeGuard.js"; -import {BuildGpu, LlamaLocks} from "../bindings/types.js"; +import {BuildGpu, LlamaLocks, LlamaVocabularyType, LlamaVocabularyTypeValues} from "../bindings/types.js"; import {GgufFileInfo} from "../gguf/types/GgufFileInfoTypes.js"; import {readGgufFileInfo} from "../gguf/readGgufFileInfo.js"; import {GgufInsights} from "../gguf/GgufInsights.js"; import {findBestOption} from "../utils/findBestOption.js"; import {InsufficientMemoryError} from "../utils/InsufficientMemoryError.js"; import {minAllowedContextSizeInCalculations} from "../config.js"; +import {GgufMetadataTokenizerTokenType} from "../gguf/types/GgufMetadataTypes.js"; +import {getConsoleLogPrefix} from "../utils/getConsoleLogPrefix.js"; import {LlamaContextOptions} from "./LlamaContext/types.js"; import {getDefaultContextBatchSize, getDefaultModelContextSize, LlamaContext} from "./LlamaContext/LlamaContext.js"; import {LlamaEmbeddingContext, LlamaEmbeddingContextOptions} from "./LlamaEmbeddingContext.js"; @@ -92,6 +94,7 @@ export class LlamaModel { /** @internal */ private _typeDescription?: ModelTypeDescription; /** @internal */ private _trainContextSize?: number; /** @internal */ private _embeddingVectorSize?: number; + /** @internal */ private _vocabularyType?: LlamaVocabularyType; public readonly onDispose = new EventRelay(); @@ -215,9 +218,9 @@ export class LlamaModel { * For example, `` will be tokenized to the BOS token if `specialTokens` is set to `true`, * otherwise it will be tokenized to tokens that corresponds to the plaintext `` string. */ - public tokenize(text: string, specialTokens?: boolean): Token[]; + public tokenize(text: string, specialTokens?: boolean | "trimLeadingSpace"): Token[]; public tokenize(text: BuiltinSpecialTokenValue, specialTokens: "builtin"): Token[]; - public tokenize(text: string, specialTokens: boolean | "builtin" = false): Token[] { + public tokenize(text: string, specialTokens: boolean | "builtin" | "trimLeadingSpace" = false): Token[] { this._ensureNotDisposed(); if (text === "") @@ -236,6 +239,29 @@ export class LlamaModel { throw new Error(`Unknown builtin special token: ${builtinToken}`); } + if (specialTokens === "trimLeadingSpace") { + specialTokens = true; + + const [workaroundToken, workaroundTokenString] = (this.tokens.bos != null && this.tokens.bosString != null) + ? [this.tokens.bos, this.tokens.bosString] + : (this.tokens.eos != null && this.tokens.eosString != null) + ? [this.tokens.eos, this.tokens.eosString] + : (this.tokens.nl != null && this.tokens.nlString != null) + ? [this.tokens.nl, this.tokens.nlString] + : [null, null]; + + if (workaroundToken != null && workaroundTokenString != null) { + const tokens = Array.from(this._model.tokenize(workaroundTokenString + text, true)) as Token[]; + const workaroundTokenIndex = tokens.indexOf(workaroundToken); + + // only use the tokenized output if it can be corrected, otherwise fallback to the default tokenization + if (workaroundTokenIndex >= 0 && workaroundTokenIndex <= 1) { + tokens.splice(0, workaroundTokenIndex + 1); + return tokens; + } + } + } + return Array.from(this._model.tokenize(text, specialTokens)) as Token[]; } @@ -249,6 +275,13 @@ export class LlamaModel { return this._model.detokenize(Uint32Array.from(tokens)); } + public getTokenType(token: Token): GgufMetadataTokenizerTokenType | null { + if (this.vocabularyType === LlamaVocabularyType.none) + return null; + + return this._model.getTokenType(token) as GgufMetadataTokenizerTokenType; + } + public async createContext(options: LlamaContextOptions) { return await withLock(this._llama._memoryLock, LlamaLocks.loadToMemory, options.createSignal, async () => { const preventDisposalHandle = this._backendModelDisposeGuard.createPreventDisposalHandle(); @@ -301,6 +334,22 @@ export class LlamaModel { return this._embeddingVectorSize; } + public get vocabularyType(): LlamaVocabularyType { + this._ensureNotDisposed(); + + if (this._vocabularyType == null) { + const vocabType = this._model.getVocabularyType(); + this._vocabularyType = LlamaVocabularyTypeValues[vocabType]; + + if (this._vocabularyType == null) { + console.warn(getConsoleLogPrefix() + "Unknown vocabulary type:", vocabType); + this._vocabularyType = LlamaVocabularyType.none; + } + } + + return this._vocabularyType; + } + /** @internal */ private _ensureNotDisposed() { if (this._disposedState.disposed) diff --git a/src/gguf/readGgufFileInfo.ts b/src/gguf/readGgufFileInfo.ts index 135620e6..eae8afa3 100644 --- a/src/gguf/readGgufFileInfo.ts +++ b/src/gguf/readGgufFileInfo.ts @@ -3,6 +3,7 @@ import {parseGguf} from "./parser/parseGguf.js"; import {GgufNetworkFetchFileReader} from "./fileReaders/GgufNetworkFetchFileReader.js"; import {GgufFsFileReader} from "./fileReaders/GgufFsFileReader.js"; import {ggufDefaultFetchRetryOptions} from "./consts.js"; +import {normalizeGgufDownloadUrl} from "./utils/normalizeGgufDownloadUrl.js"; /** @@ -50,7 +51,7 @@ export async function readGgufFileInfo(pathOrUrl: string, { function createFileReader() { if (sourceType === "network" || (sourceType == null && (pathOrUrl.startsWith("http://") || pathOrUrl.startsWith("https://")))) { return new GgufNetworkFetchFileReader({ - url: pathOrUrl, + url: normalizeGgufDownloadUrl(pathOrUrl), retryOptions: fetchRetryOptions, headers: fetchHeaders, signal diff --git a/src/gguf/types/GgufMetadataTypes.ts b/src/gguf/types/GgufMetadataTypes.ts index 923844ba..e6554415 100644 --- a/src/gguf/types/GgufMetadataTypes.ts +++ b/src/gguf/types/GgufMetadataTypes.ts @@ -35,12 +35,12 @@ export type GgufMetadata GgufArchitectureType extends A ? { readonly [key in GgufArchitectureType]?: key extends keyof GgufMetadataLlmToType ? GgufMetadataLlmToType[key] - : GgufMetadataLlmDefaultArchitectureType + : GgufMetadataDefaultArchitectureType } : { readonly [key in A]: key extends keyof GgufMetadataLlmToType ? GgufMetadataLlmToType[key] - : GgufMetadataLlmDefaultArchitectureType + : GgufMetadataDefaultArchitectureType } ); @@ -197,14 +197,14 @@ export type GgufMetadataTokenizer = { readonly chat_template?: string }; -export const enum GgufMetadataLlmPoolingType { +export const enum GgufMetadataArchitecturePoolingType { unspecified = -1, none = 0, mean = 1, max = 2, } -export type GgufMetadataLlmDefaultArchitectureType = { +export type GgufMetadataDefaultArchitectureType = { readonly vocab_size?: number, readonly context_length?: number, readonly embedding_length?: number, @@ -214,7 +214,7 @@ export type GgufMetadataLlmDefaultArchitectureType = { readonly tensor_data_layout?: string, readonly expert_count?: number, readonly expert_used_count?: number, - readonly pooling_type?: GgufMetadataLlmPoolingType, + readonly pooling_type?: GgufMetadataArchitecturePoolingType, readonly logit_scale?: number, readonly attention?: { diff --git a/src/gguf/utils/normalizeGgufDownloadUrl.ts b/src/gguf/utils/normalizeGgufDownloadUrl.ts new file mode 100644 index 00000000..3c957f9e --- /dev/null +++ b/src/gguf/utils/normalizeGgufDownloadUrl.ts @@ -0,0 +1,20 @@ +export function normalizeGgufDownloadUrl(url: string) { + const parsedUrl = new URL(url); + + if (parsedUrl.hostname === "huggingface.co") { + const pathnameParts = parsedUrl.pathname.split("/"); + + if (pathnameParts.length > 3 && pathnameParts[3] === "blob") { + const newUrl = new URL(url); + pathnameParts[3] = "resolve"; + newUrl.pathname = pathnameParts.join("/"); + + if (newUrl.searchParams.get("download") !== "true") + newUrl.searchParams.set("download", "true"); + + return newUrl.href; + } + } + + return url; +} diff --git a/src/index.ts b/src/index.ts index 728dc395..555db6ea 100644 --- a/src/index.ts +++ b/src/index.ts @@ -2,7 +2,7 @@ import {DisposedError} from "lifecycle-utils"; import {Llama} from "./bindings/Llama.js"; import {getLlama, LlamaOptions} from "./bindings/getLlama.js"; import {NoBinaryFoundError} from "./bindings/utils/NoBinaryFoundError.js"; -import {LlamaLogLevel, LlamaLogLevelGreaterThan, LlamaLogLevelGreaterThanOrEqual} from "./bindings/types.js"; +import {LlamaLogLevel, LlamaLogLevelGreaterThan, LlamaLogLevelGreaterThanOrEqual, LlamaVocabularyType} from "./bindings/types.js"; import {LlamaModel, LlamaModelInfillTokens, type LlamaModelOptions, LlamaModelTokens} from "./evaluator/LlamaModel.js"; import {LlamaGrammar, type LlamaGrammarOptions} from "./evaluator/LlamaGrammar.js"; import {LlamaJsonSchemaGrammar} from "./evaluator/LlamaJsonSchemaGrammar.js"; @@ -41,17 +41,20 @@ import {FalconChatWrapper} from "./chatWrappers/FalconChatWrapper.js"; import {AlpacaChatWrapper} from "./chatWrappers/AlpacaChatWrapper.js"; import {FunctionaryChatWrapper} from "./chatWrappers/FunctionaryChatWrapper.js"; import {GemmaChatWrapper} from "./chatWrappers/GemmaChatWrapper.js"; -import {TemplateChatWrapper, type TemplateChatWrapperOptions} from "./TemplateChatWrapper.js"; -import {resolveChatWrapperBasedOnModel} from "./chatWrappers/resolveChatWrapperBasedOnModel.js"; +import {TemplateChatWrapper, type TemplateChatWrapperOptions} from "./chatWrappers/generic/TemplateChatWrapper.js"; +import {JinjaTemplateChatWrapper, type JinjaTemplateChatWrapperOptions} from "./chatWrappers/generic/JinjaTemplateChatWrapper.js"; import { - resolveChatWrapperBasedOnWrapperTypeName, chatWrapperTypeNames, type ChatWrapperTypeName -} from "./bindings/utils/resolveChatWrapperBasedOnWrapperTypeName.js"; + resolvableChatWrapperTypeNames, type ResolvableChatWrapperTypeName, specializedChatWrapperTypeNames, + type SpecializedChatWrapperTypeName, templateChatWrapperTypeNames, type TemplateChatWrapperTypeName, resolveChatWrapper, + type ResolveChatWrapperOptions +} from "./chatWrappers/utils/resolveChatWrapper.js"; import { LlamaText, SpecialToken, BuiltinSpecialToken, isLlamaText, tokenizeText, type LlamaTextJSON, type LlamaTextJSONValue, type LlamaTextSpecialTokenJSON } from "./utils/LlamaText.js"; import {appendUserMessageToChatHistory} from "./utils/appendUserMessageToChatHistory.js"; import {getModuleVersion} from "./utils/getModuleVersion.js"; +import {readGgufFileInfo} from "./gguf/readGgufFileInfo.js"; import { type ChatHistoryItem, type ChatModelFunctionCall, type ChatModelFunctions, type ChatModelResponse, @@ -62,6 +65,14 @@ import { type GbnfJsonArraySchema, type GbnfJsonBasicSchema, type GbnfJsonConstSchema, type GbnfJsonEnumSchema, type GbnfJsonObjectSchema, type GbnfJsonOneOfSchema, type GbnfJsonSchema, type GbnfJsonSchemaImmutableType, type GbnfJsonSchemaToType } from "./utils/gbnfJson/types.js"; +import {type GgufFileInfo} from "./gguf/types/GgufFileInfoTypes.js"; +import { + type GgufMetadata, type GgufMetadataLlmToType, GgufArchitectureType, GgufFileType, GgufMetadataTokenizerTokenType, + GgufMetadataArchitecturePoolingType, type GgufMetadataGeneral, type GgufMetadataTokenizer, type GgufMetadataDefaultArchitectureType, + type GgufMetadataLlmLLaMA, type GgufMetadataMPT, type GgufMetadataGPTNeoX, type GgufMetadataGPTJ, type GgufMetadataGPT2, + type GgufMetadataBloom, type GgufMetadataFalcon, type GgufMetadataMamba, type GgufMetadataRWKV, isGgufMetadataOfArchitectureType +} from "./gguf/types/GgufMetadataTypes.js"; +import {GgmlType, type GgufTensorInfo} from "./gguf/types/GgufTensorInfoTypes.js"; export { @@ -130,10 +141,16 @@ export { GemmaChatWrapper, TemplateChatWrapper, type TemplateChatWrapperOptions, - resolveChatWrapperBasedOnModel, - resolveChatWrapperBasedOnWrapperTypeName, - chatWrapperTypeNames, - type ChatWrapperTypeName, + JinjaTemplateChatWrapper, + type JinjaTemplateChatWrapperOptions, + resolveChatWrapper, + type ResolveChatWrapperOptions, + resolvableChatWrapperTypeNames, + type ResolvableChatWrapperTypeName, + specializedChatWrapperTypeNames, + type SpecializedChatWrapperTypeName, + templateChatWrapperTypeNames, + type TemplateChatWrapperTypeName, LlamaText, SpecialToken, BuiltinSpecialToken, @@ -163,6 +180,30 @@ export { type GbnfJsonOneOfSchema, type GbnfJsonObjectSchema, type GbnfJsonArraySchema, + LlamaVocabularyType, LlamaLogLevelGreaterThan, - LlamaLogLevelGreaterThanOrEqual + LlamaLogLevelGreaterThanOrEqual, + readGgufFileInfo, + type GgufFileInfo, + type GgufMetadata, + type GgufTensorInfo, + type GgufMetadataLlmToType, + GgufArchitectureType, + GgufFileType, + GgufMetadataTokenizerTokenType, + GgufMetadataArchitecturePoolingType, + type GgufMetadataGeneral, + type GgufMetadataTokenizer, + type GgufMetadataDefaultArchitectureType, + type GgufMetadataLlmLLaMA, + type GgufMetadataMPT, + type GgufMetadataGPTNeoX, + type GgufMetadataGPTJ, + type GgufMetadataGPT2, + type GgufMetadataBloom, + type GgufMetadataFalcon, + type GgufMetadataMamba, + type GgufMetadataRWKV, + GgmlType, + isGgufMetadataOfArchitectureType }; diff --git a/src/types.ts b/src/types.ts index f4d39cf0..13fd619c 100644 --- a/src/types.ts +++ b/src/types.ts @@ -6,7 +6,7 @@ export type Token = number & { }; export type Tokenizer = { - tokenize(text: string, specialTokens?: boolean): Token[], + tokenize(text: string, specialTokens?: boolean | "trimLeadingSpace"): Token[], tokenize(text: BuiltinSpecialTokenValue, specialTokens: "builtin"): Token[] }["tokenize"]; diff --git a/src/utils/LlamaText.ts b/src/utils/LlamaText.ts index 0d8a15cd..84c87ac4 100644 --- a/src/utils/LlamaText.ts +++ b/src/utils/LlamaText.ts @@ -6,7 +6,8 @@ export type LlamaTextClass = { ...values: readonly (V | LlamaText | V2 | LlamaText | number | boolean | readonly (LlamaText | V | LlamaText | V2)[])[] ): LlamaText, - fromJSON(json: LlamaTextJSON): LlamaText + fromJSON(json: LlamaTextJSON): LlamaText, + compare(a: LlamaText, b: LlamaText): boolean }; export type LlamaText = { @@ -20,7 +21,11 @@ export type LlamaText = { joinValues(separator: LlamaText | V): LlamaText, toString(): string, toJSON(): LlamaTextJSON, - tokenize(tokenizer: Tokenizer): Token[] + tokenize(tokenizer: Tokenizer): Token[], + compare(other: LlamaText): boolean, + trimStart(): LlamaText, + trimEnd(): LlamaText, + includes(value: LlamaText): boolean }; export type LlamaTextValue = string | SpecialToken; @@ -49,6 +54,20 @@ LlamaText.fromJSON = function fromJSON(json: LlamaTextJSON) { }) ); }; +LlamaText.compare = function compare(a: LlamaText, b: LlamaText) { + if (!isLlamaText(a) || !isLlamaText(b)) + return false; + + if (a.values.length !== b.values.length) + return false; + + for (let i = 0; i < a.values.length; i++) { + if (!compareLlamaTextValues(a.values[i], b.values[i])) + return false; + } + + return true; +}; export class SpecialToken { public readonly value: string; @@ -61,8 +80,8 @@ export class SpecialToken { return this.value; } - public tokenize(tokenizer: Tokenizer): Token[] { - return tokenizer(this.value, true); + public tokenize(tokenizer: Tokenizer, trimLeadingSpace: boolean = false): Token[] { + return tokenizer(this.value, trimLeadingSpace ? "trimLeadingSpace" : true); } public toJSON(): LlamaTextSpecialTokenJSON { @@ -82,6 +101,16 @@ export class SpecialToken { public static isSpecialTokenJSON(value: LlamaTextJSONValue): value is LlamaTextSpecialTokenJSON { return value != null && typeof value === "object" && value.type === "specialToken"; } + + /** + * Wraps the value with a `SpecialToken` only if `shouldWrap` is true + */ + public static wrapIf(shouldWrap: boolean, value: string): SpecialToken | string { + if (shouldWrap) + return new SpecialToken(value); + else + return value; + } } export type BuiltinSpecialTokenValue = "BOS" | "EOS" | "NL"; @@ -160,9 +189,12 @@ const LlamaTextPrototypeFunctions: Partial = { const res: Token[] = []; for (const value of this.values) { - if (value instanceof SpecialToken) { + if (value instanceof BuiltinSpecialToken) { res.push(...tokenizer(textToTokenize, false), ...value.tokenize(tokenizer)); textToTokenize = ""; + } else if (value instanceof SpecialToken) { + res.push(...tokenizer(textToTokenize, false), ...value.tokenize(tokenizer, res.length > 0 || textToTokenize.length > 0)); + textToTokenize = ""; } else textToTokenize += value; } @@ -178,6 +210,99 @@ const LlamaTextPrototypeFunctions: Partial = { else return value satisfies LlamaTextJSONValue; }); + }, + compare(this: LlamaText, other: LlamaText) { + return LlamaText.compare(this, other); + }, + trimStart(this: LlamaText) { + const newValues = this.values.slice(); + + while (newValues.length > 0) { + const firstValue = newValues[0]; + + if (firstValue instanceof BuiltinSpecialToken) + break; + + if (firstValue instanceof SpecialToken) { + const newValue = firstValue.value.trimStart(); + if (newValue === "") { + newValues.shift(); + continue; + } else if (newValue !== firstValue.value) { + newValues[0] = new SpecialToken(newValue); + break; + } + + break; + } else if (typeof firstValue === "string") { + const newValue = firstValue.trimStart(); + if (newValue === "") { + newValues.shift(); + continue; + } else if (newValue !== firstValue) { + newValues[0] = newValue; + break; + } + + break; + } else + void (firstValue satisfies never); + } + + return createLlamaText(newValues); + }, + trimEnd(this: LlamaText) { + const newValues = this.values.slice(); + + while (newValues.length > 0) { + const lastValue = newValues[newValues.length - 1]; + + if (lastValue instanceof BuiltinSpecialToken) + break; + + if (lastValue instanceof SpecialToken) { + const newValue = lastValue.value.trimEnd(); + if (newValue === "") { + newValues.pop(); + continue; + } else if (newValue !== lastValue.value) { + newValues[newValues.length - 1] = new SpecialToken(newValue); + break; + } + + break; + } else if (typeof lastValue === "string") { + const newValue = lastValue.trimEnd(); + if (newValue === "") { + newValues.pop(); + continue; + } else if (newValue !== lastValue) { + newValues[newValues.length - 1] = newValue; + break; + } + + break; + } else + void (lastValue satisfies never); + } + + return createLlamaText(newValues); + }, + includes(this: LlamaText, value: LlamaText) { + for (let i = 0; i < this.values.length; i++) { + if (compareLlamaTextValues(this.values[i], value.values[0])) { + let j = 1; + for (; j < value.values.length; j++) { + if (!compareLlamaTextValues(this.values[i + j], value.values[j])) + break; + } + + if (j === value.values.length) + return true; + } + } + + return false; } }; @@ -235,6 +360,30 @@ function createLlamaText(history: readonly LlamaTextValue[]): LlamaText { writable: false, configurable: false, enumerable: false + }, + ["compare" satisfies keyof LlamaText]: { + value: LlamaTextPrototypeFunctions.compare, + writable: false, + configurable: false, + enumerable: false + }, + ["trimStart" satisfies keyof LlamaText]: { + value: LlamaTextPrototypeFunctions.trimStart, + writable: false, + configurable: false, + enumerable: false + }, + ["trimEnd" satisfies keyof LlamaText]: { + value: LlamaTextPrototypeFunctions.trimEnd, + writable: false, + configurable: false, + enumerable: false + }, + ["includes" satisfies keyof LlamaText]: { + value: LlamaTextPrototypeFunctions.includes, + writable: false, + configurable: false, + enumerable: false } }); @@ -269,10 +418,36 @@ function createHistoryFromStringsAndValues, item: LlamaTextValue) { + if (res.length === 0) { + res.push(item); + return res; + } + + const lastItem = res[res.length - 1]; + + if (lastItem instanceof BuiltinSpecialToken || item instanceof BuiltinSpecialToken) { + res.push(item); + return res; + } + + if (typeof lastItem === "string" && typeof item === "string") { + res[res.length - 1] += item; + return res; + } else if (lastItem instanceof SpecialToken && item instanceof SpecialToken) { + res[res.length - 1] = new SpecialToken(lastItem.value + item.value); + return res; + } + + res.push(item); + return res; + } + if (!isTemplateStringsArray(strings)) { return ([strings] as LlamaTextInputValue[]) .concat(values) - .reduce(addItemToRes, []); + .reduce(addItemToRes, []) + .reduce(squashAdjacentItems, []); } @@ -285,10 +460,22 @@ function createHistoryFromStringsAndValues { + const template1 = + "{{ bos_token }}" + + "{% for message in messages %}" + + "" + "{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}" + + "" + "" + "{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}" + + "" + "{% endif %}" + + "" + "{% if message['role'] == 'user' %}" + + "" + "" + "{{ '[INST] ' + message['content'] + ' [/INST]' }}" + + "" + "{% elif message['role'] == 'assistant' %}" + + "" + "" + "{{ message['content'] + eos_token}}" + + "" + "{% else %}" + + "" + "" + "{{ raise_exception('Only user and assistant roles are supported!') }}" + + "" + "{% endif %}" + + "{% endfor %}"; + const template2 = + "{% for message in messages %}" + + "" + "{% if message['role'] == 'user' %}" + + "" + "" + "{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}" + + "" + "{% elif message['role'] == 'system' %}" + + "" + "" + "{{ '<>\\n' + message['content'] + '\\n<>\\n\\n' }}" + + "" + "{% elif message['role'] == 'assistant' %}" + + "" + "" + "{{ ' ' + message['content'] + ' ' + eos_token }}" + + "" + "{% endif %}" + + "{% endfor %}"; + const template3 = + "{% for message in messages %}" + + "" + "{% if message['role'] == 'user' %}" + + "" + "" + "{{ bos_token + '[INST] ' + message['content'] + ' [/INST]' }}" + + "" + "{% elif message['role'] == 'assistant' %}" + + "" + "" + "{{ ' ' + message['content'] + ' ' + eos_token }}" + + "" + "{% endif %}" + + "{% endfor %}"; + + const conversationHistory: ChatHistoryItem[] = [{ + type: "system", + text: defaultChatSystemPrompt + }, { + type: "user", + text: "Hi there!" + }, { + type: "model", + response: ["Hello!"] + }]; + const conversationHistory2: ChatHistoryItem[] = [{ + type: "system", + text: defaultChatSystemPrompt + }, { + type: "user", + text: "Hi there!" + }, { + type: "model", + response: ["Hello!"] + }, { + type: "user", + text: "How are you?" + }, { + type: "model", + response: ["I'm good, how are you?"] + }]; + const conversationHistory3: ChatHistoryItem[] = [{ + type: "user", + text: "Hi there!" + }, { + type: "model", + response: ["Hello!"] + }, { + type: "user", + text: "How are you?" + }]; + const conversationHistoryWithFunctionCalls: ChatHistoryItem[] = [{ + type: "user", + text: "Hi there!" + }, { + type: "model", + response: ["Hello!", { + type: "functionCall", + name: "func2", + params: { + message: "Hello", + feeling: "good", + words: 1 + }, + result: { + yes: true, + message: "ok" + } + }] + }, { + type: "user", + text: "How are you?" + }]; + const exampleFunctions = { + func1: { + + }, + func2: { + params: { + type: "object", + properties: { + message: { + type: "string" + }, + feeling: { + enum: ["good", "bad"] + }, + words: { + type: "number" + } + } + } + }, + func3: { + description: "Some description here", + params: { + type: "array", + items: { + type: "string" + } + } + } + } as const; + + test("with system prompt support", () => { + const chatWrapper = new JinjaTemplateChatWrapper({ + template: template2 + }); + const {contextText, stopGenerationTriggers} = chatWrapper.generateContextText(conversationHistory); + + expect(contextText.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "<> + ", + }, + "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": " + <> + + ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello!", + ] + `); + expect(stopGenerationTriggers).toMatchInlineSnapshot(` + [ + LlamaText [ + { + "builtin": true, + "type": "specialToken", + "value": "EOS", + }, + ], + LlamaText [ + { + "type": "specialToken", + "value": " ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "EOS", + }, + ], + ] + `); + + const {contextText: contextText2} = chatWrapper.generateContextText(conversationHistory2); + + expect(contextText2.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "<> + ", + }, + "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": " + <> + + ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello!", + { + "type": "specialToken", + "value": " ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "EOS", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "How are you?", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "I'm good, how are you?", + ] + `); + + const {contextText: contextText3} = chatWrapper.generateContextText(conversationHistory); + const {contextText: contextText3WithOpenModelResponse} = chatWrapper.generateContextText([ + ...conversationHistory, + { + type: "model", + response: [] + } + ]); + + expect(contextText3.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "<> + ", + }, + "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": " + <> + + ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello!", + ] + `); + + expect(contextText3WithOpenModelResponse.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "<> + ", + }, + "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": " + <> + + ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello! + + ", + ] + `); + + const {contextText: contextText4} = chatWrapper.generateContextText(conversationHistory3); + + expect(contextText4.values).toMatchInlineSnapshot(` + [ + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello!", + { + "type": "specialToken", + "value": " ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "EOS", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "How are you?", + { + "type": "specialToken", + "value": " [/INST]", + }, + ] + `); + }); + + test("without system prompt support", () => { + const chatWrapper = new JinjaTemplateChatWrapper({ + template: template1 + }); + const {contextText} = chatWrapper.generateContextText(conversationHistory); + + expect(contextText.values).toMatchInlineSnapshot(` + [ + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "System: You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. + + Hi there!", + { + "type": "specialToken", + "value": " [/INST]", + }, + "Hello!", + ] + `); + }); + + test("without system prompt support with no exception from the template", () => { + const chatWrapper = new JinjaTemplateChatWrapper({ + template: template3 + }); + const {contextText} = chatWrapper.generateContextText(conversationHistory); + + expect(contextText.values).toMatchInlineSnapshot(` + [ + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "System: You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. + + Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello!", + ] + `); + }); + + test("without system prompt support with no exception from the template 2", () => { + const chatWrapper = new JinjaTemplateChatWrapper({ + template: template2, + systemRoleName: "something1" + }); + const {contextText} = chatWrapper.generateContextText(conversationHistory); + + expect(contextText.values).toMatchInlineSnapshot(` + [ + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "System: You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. + + Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello!", + ] + `); + }); + + test("without joining adjacent messages of the same type", () => { + const chatWrapper = new JinjaTemplateChatWrapper({ + template: template2, + joinAdjacentMessagesOfTheSameType: false + }); + const {contextText} = chatWrapper.generateContextText([conversationHistory[0], ...conversationHistory]); + + expect(contextText.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "<> + ", + }, + "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": " + <> + + <> + ", + }, + "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", + { + "type": "specialToken", + "value": " + <> + + ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello!", + ] + `); + }); + + test("functions", () => { + const chatWrapper = new JinjaTemplateChatWrapper({ + template: template2 + }); + const {contextText} = chatWrapper.generateContextText(conversationHistory, { + availableFunctions: exampleFunctions + }); + + expect(contextText.values).toMatchInlineSnapshot(` + [ + { + "type": "specialToken", + "value": "<> + ", + }, + "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. + + The assistant calls the provided functions as needed to retrieve information instead of relying on things it already knows. + Provided functions: + \`\`\` + function func1(); + + function func2(params: message: string, feeling: "good" | "bad", words: number); + + // Some description here + function func3(params: (string)[]); + \`\`\` + + Calling any of the provided functions can be done like this: + [[call: functionName({ someKey: "someValue" })]] + + After calling a function the result will appear afterwards and be visible only to the assistant, so the assistant has to tell the user about it outside of the function call context. + The assistant calls the functions in advance before telling the user about the result", + { + "type": "specialToken", + "value": " + <> + + ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello!", + ] + `); + }); + + test("functions template", () => { + const chatWrapper = new JinjaTemplateChatWrapper({ + template: template3, + functionCallMessageTemplate: [ + "[[call: {{functionName}}({{functionParams}})]]", + " [[result: {{functionCallResult}}]]" + ] + }); + const {contextText} = chatWrapper.generateContextText(conversationHistoryWithFunctionCalls, { + availableFunctions: exampleFunctions + }); + + expect(contextText.values).toMatchInlineSnapshot(` + [ + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "System: The assistant calls the provided functions as needed to retrieve information instead of relying on things it already knows. + Provided functions: + \`\`\` + function func1(); + + function func2(params: message: string, feeling: "good" | "bad", words: number); + + // Some description here + function func3(params: (string)[]); + \`\`\` + + Calling any of the provided functions can be done like this: + [[call: functionName({ someKey: "someValue" })]] + + After calling a function the result will appear afterwards and be visible only to the assistant, so the assistant has to tell the user about it outside of the function call context. + The assistant calls the functions in advance before telling the user about the result + + Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello! + [[call: func2({"message":"Hello","feeling":"good","words":1})]] [[result: {"yes":true,"message":"ok"}]]", + { + "type": "specialToken", + "value": " ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "EOS", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "How are you?", + { + "type": "specialToken", + "value": " [/INST]", + }, + ] + `); + }); + + test("functions template 2", () => { + const chatWrapper = new JinjaTemplateChatWrapper({ + template: template3, + functionCallMessageTemplate: [ + "\nCall function: {{functionName}} with params {{functionParams}}.", + "\nFunction result: {{functionCallResult}}\n" + ] + }); + const {contextText} = chatWrapper.generateContextText(conversationHistoryWithFunctionCalls, { + availableFunctions: exampleFunctions + }); + + expect(contextText.values).toMatchInlineSnapshot(` + [ + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "System: The assistant calls the provided functions as needed to retrieve information instead of relying on things it already knows. + Provided functions: + \`\`\` + function func1(); + + function func2(params: message: string, feeling: "good" | "bad", words: number); + + // Some description here + function func3(params: (string)[]); + \`\`\` + + Calling any of the provided functions can be done like this: + Call function: functionName with params { someKey: "someValue" }. + + After calling a function the result will appear afterwards and be visible only to the assistant, so the assistant has to tell the user about it outside of the function call context. + The assistant calls the functions in advance before telling the user about the result + + Hi there!", + { + "type": "specialToken", + "value": " [/INST] ", + }, + "Hello! + + Call function: func2 with params {"message":"Hello","feeling":"good","words":1}. + Function result: {"yes":true,"message":"ok"} + ", + { + "type": "specialToken", + "value": " ", + }, + { + "builtin": true, + "type": "specialToken", + "value": "EOS", + }, + { + "builtin": true, + "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialToken", + "value": "[INST] ", + }, + "How are you?", + { + "type": "specialToken", + "value": " [/INST]", + }, + ] + `); + }); + + test("Fails when messages are not present in the render output", () => { + try { + new JinjaTemplateChatWrapper({ + template: template2, + userRoleName: "something1" + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(String(err)).toMatchInlineSnapshot('"Error: The provided Jinja template failed that sanity test: Error: Some input messages are not present in the generated Jinja template output"'); + } + }); + + test("Fails when messages are not present in the render output 2", () => { + try { + new JinjaTemplateChatWrapper({ + template: template2, + modelRoleName: "something1" + }); + expect.unreachable("Should have thrown an error"); + } catch (err) { + expect(String(err)).toMatchInlineSnapshot('"Error: The provided Jinja template failed that sanity test: Error: Some input messages are not present in the generated Jinja template output"'); + } + }); +}); diff --git a/test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts b/test/standalone/chatWrappers/generic/LlamaChatPromptWrapper.test.ts similarity index 97% rename from test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts rename to test/standalone/chatWrappers/generic/LlamaChatPromptWrapper.test.ts index 7bb478ca..3e10a066 100644 --- a/test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts +++ b/test/standalone/chatWrappers/generic/LlamaChatPromptWrapper.test.ts @@ -1,6 +1,6 @@ import {describe, expect, test} from "vitest"; -import {ChatHistoryItem, LlamaChatWrapper} from "../../../src/index.js"; -import {defaultChatSystemPrompt} from "../../../src/config.js"; +import {ChatHistoryItem, LlamaChatWrapper} from "../../../../src/index.js"; +import {defaultChatSystemPrompt} from "../../../../src/config.js"; describe("LlamaChatWrapper", () => { diff --git a/test/standalone/chatWrappers/utils/resolveChatWrapper.test.ts b/test/standalone/chatWrappers/utils/resolveChatWrapper.test.ts new file mode 100644 index 00000000..6532d6c3 --- /dev/null +++ b/test/standalone/chatWrappers/utils/resolveChatWrapper.test.ts @@ -0,0 +1,207 @@ +import {describe, expect, test} from "vitest"; +import { + AlpacaChatWrapper, ChatMLChatWrapper, FalconChatWrapper, FunctionaryChatWrapper, GemmaChatWrapper, GeneralChatWrapper, LlamaChatWrapper, + resolveChatWrapper +} from "../../../../src/index.js"; + + +const alpacaJinjaTemplate = ` +{%- for message in messages %} + {%- if message['role'] == 'system' -%} + {{- message['content'] + '\\n\\n' -}} + {%- elif message['role'] == 'user' -%} + {{- '### Instruction:\\n' + message['content'] + '\\n\\n'-}} + {%- else -%} + {{- '### Response:\\n' + message['content'] + '\\n\\n' -}} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- '### Response:\\n'-}} +{%- endif -%} +`.slice(1, -1); + +const chatMLJinjaTemplate = ` +{%- for message in messages %} + {%- if message['role'] == 'system' -%} + {{- '<|im_start|>system\\n' + message['content'].strip() + '<|im_end|>\\n' -}} + {%- elif message['role'] == 'user' -%} + {{- '<|im_start|>user\\n' + message['content'].strip() + '<|im_end|>\\n'-}} + {%- else -%} + {{- '<|im_start|>assistant\\n' + message['content'] + '<|im_end|>\\n' -}} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- '<|im_start|>assistant\\n'-}} +{%- endif -%} +`.slice(1, -1); + +const falconJinjaTemplate = ` +{%- if messages[0]['role'] == 'system' %} + {%- set loop_messages = messages[1:] %} + {%- set system_message = messages[0]['content'] %} +{%- else %} + {%- set loop_messages = messages %} + {%- set system_message = '' %} +{%- endif %} +{%- for message in loop_messages %} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{- raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif %} + {%- if loop.index0 == 0 %} + {{- system_message.strip() }} + {%- endif %} + {%- if message['role'] == 'user' %} + {{- '\n\nUser: ' + message['content'].strip() }} + {%- elif message['role'] == 'assistant' %} + {{- '\n\nAssistant: ' + message['content'].strip() }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '\n\nAssistant:' }} +{%- endif %} +`.slice(1, -1); + +const funcationaryJinjaTemplate = "{% for message in messages %}\n{% if message['role'] == 'user' or message['role'] == 'system' %}\n{{ '<|from|>' + message['role'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% elif message['role'] == 'tool' %}\n{{ '<|from|>' + message['name'] + '\n<|recipient|>all\n<|content|>' + message['content'] + '\n' }}{% else %}\n{% set contain_content='no'%}\n{% if message['content'] is not none %}\n{{ '<|from|>assistant\n<|recipient|>all\n<|content|>' + message['content'] }}{% set contain_content='yes'%}\n{% endif %}\n{% if 'tool_calls' in message and message['tool_calls'] is not none %}\n{% for tool_call in message['tool_calls'] %}\n{% set prompt='<|from|>assistant\n<|recipient|>' + tool_call['function']['name'] + '\n<|content|>' + tool_call['function']['arguments'] %}\n{% if loop.index == 1 and contain_content == \"no\" %}\n{{ prompt }}{% else %}\n{{ '\n' + prompt}}{% endif %}\n{% endfor %}\n{% endif %}\n{{ '<|stop|>\n' }}{% endif %}\n{% endfor %}\n{% if add_generation_prompt %}{{ '<|from|>assistant\n<|recipient|>' }}{% endif %}"; + +const gemmaJinjaTemplate = ` +{%- if messages[0]['role'] == 'system' %} + {{- raise_exception('System role not supported') }} +{%- endif %} +{%- for message in messages %} + {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %} + {{- raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }} + {%- endif %} + {%- if (message['role'] == 'assistant') %} + {%- set role = 'model' %} + {%- else %} + {%- set role = message['role'] %} + {%- endif %} + {{- '' + role + '\n' + message['content'] | trim + '\n' }} +{%- endfor %} +{%- if add_generation_prompt %} + {{- 'model\n' }} +{%- endif %} +`.slice(1, -1); + +const generalJinjaTemplate = ` +{%- for message in messages %} + {%- if message['role'] == 'system' -%} + {{- message['content'] + '\\n\\n' -}} + {%- elif message['role'] == 'user' -%} + {{- '### Human\\n' + message['content'] + '\\n\\n'-}} + {%- else -%} + {{- '### Assistant\\n' + message['content'] + '\\n\\n' -}} + {%- endif -%} +{%- endfor -%} +{%- if add_generation_prompt -%} + {{- '### Assistant\\n'-}} +{%- endif -%} +`.slice(1, -1); + +const llamaChatJinjaTemplate = ` +{%- set ns = namespace(found=false) -%} +{%- for message in messages -%} + {%- if message['role'] == 'system' -%} + {%- set ns.found = true -%} + {%- endif -%} +{%- endfor -%} +{%- if not ns.found -%} + {{- '[INST] <>\n' + 'Answer the questions.' + '\n<>\n\n' -}} +{%- endif %} +{%- for message in messages %} + {%- if message['role'] == 'system' -%} + {{- '[INST] <>\n' + message['content'] + '\n<>\n\n' -}} + {%- elif message['role'] == 'user' -%} + {{- message['content'] + ' [/INST] '-}} + {%- else -%} + {{- message['content'] + eos_token + bos_token + '[INST] ' -}} + {%- endif -%} +{%- endfor -%} +`.slice(1, -1); + + +describe("resolveChatWrapper", () => { + test("should resolve to specialized AlpacaChatWrapper", () => { + const chatWrapper = resolveChatWrapper({ + customWrapperSettings: { + jinjaTemplate: { + template: alpacaJinjaTemplate + } + }, + fallbackToOtherWrappersOnJinjaError: false + }); + expect(chatWrapper).to.be.instanceof(AlpacaChatWrapper); + }); + + test("should resolve to specialized ChatMLChatWrapper", () => { + const chatWrapper = resolveChatWrapper({ + customWrapperSettings: { + jinjaTemplate: { + template: chatMLJinjaTemplate + } + }, + fallbackToOtherWrappersOnJinjaError: false + }); + expect(chatWrapper).to.be.instanceof(ChatMLChatWrapper); + }); + + test("should resolve to specialized FalconChatWrapper", () => { + const chatWrapper = resolveChatWrapper({ + customWrapperSettings: { + jinjaTemplate: { + template: falconJinjaTemplate + } + }, + fallbackToOtherWrappersOnJinjaError: false + }); + expect(chatWrapper).to.be.instanceof(FalconChatWrapper); + }); + + test("should resolve to specialized FunctionaryChatWrapper", () => { + const chatWrapper = resolveChatWrapper({ + customWrapperSettings: { + jinjaTemplate: { + template: funcationaryJinjaTemplate + } + }, + fallbackToOtherWrappersOnJinjaError: false + }); + expect(chatWrapper).to.be.instanceof(FunctionaryChatWrapper); + }); + + test("should resolve to specialized GemmaChatWrapper", () => { + const chatWrapper = resolveChatWrapper({ + customWrapperSettings: { + jinjaTemplate: { + template: gemmaJinjaTemplate + } + }, + fallbackToOtherWrappersOnJinjaError: false + }); + expect(chatWrapper).to.be.instanceof(GemmaChatWrapper); + }); + + test("should resolve to specialized GeneralChatWrapper", () => { + const chatWrapper = resolveChatWrapper({ + customWrapperSettings: { + jinjaTemplate: { + template: generalJinjaTemplate + } + }, + fallbackToOtherWrappersOnJinjaError: false + }); + expect(chatWrapper).to.be.instanceof(GeneralChatWrapper); + }); + + test("should resolve to specialized LlamaChatWrapper", async () => { + const chatWrapper = resolveChatWrapper({ + customWrapperSettings: { + jinjaTemplate: { + template: llamaChatJinjaTemplate + } + }, + fallbackToOtherWrappersOnJinjaError: false + }); + expect(chatWrapper).to.be.instanceof(LlamaChatWrapper); + }); +}); diff --git a/test/utils/helpers/llamaTextSerializer.ts b/test/utils/helpers/llamaTextSerializer.ts new file mode 100644 index 00000000..cf7d70e7 --- /dev/null +++ b/test/utils/helpers/llamaTextSerializer.ts @@ -0,0 +1,11 @@ +import {SnapshotSerializer} from "vitest"; +import {isLlamaText} from "../../../src/index.js"; + +export default { + serialize(value, config, indentation, depth, refs, printer) { + return "LlamaText " + printer(value.values, config, indentation, depth, refs); + }, + test(value) { + return isLlamaText(value); + } +} satisfies SnapshotSerializer; diff --git a/vitest.config.ts b/vitest.config.ts index d7ec215d..ff41ab8f 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -10,6 +10,7 @@ export default defineConfig({ minThreads: 1, maxThreads: 1 } - } + }, + snapshotSerializers: ["./test/utils/helpers/llamaTextSerializer.ts"] } }); From 517d8ee63cd7ee750cc5de69f04806f1b4718ad8 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 2 Apr 2024 20:30:16 +0300 Subject: [PATCH 34/52] feat: improve `resolveChatWrapper` resolution algorithm --- src/chatWrappers/AlpacaChatWrapper.ts | 2 +- src/chatWrappers/FalconChatWrapper.ts | 2 +- src/chatWrappers/GeneralChatWrapper.ts | 2 +- src/chatWrappers/LlamaChatWrapper.ts | 13 +++++---- src/chatWrappers/utils/resolveChatWrapper.ts | 30 ++++++++++++++++---- 5 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/chatWrappers/AlpacaChatWrapper.ts b/src/chatWrappers/AlpacaChatWrapper.ts index 09bb6297..608aadb7 100644 --- a/src/chatWrappers/AlpacaChatWrapper.ts +++ b/src/chatWrappers/AlpacaChatWrapper.ts @@ -31,7 +31,7 @@ export class AlpacaChatWrapper extends GeneralChatWrapper { /** @internal */ public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { - return [{ + return [{}, { allowSpecialTokensInTitles: true }] satisfies Partial[0]>[]; } diff --git a/src/chatWrappers/FalconChatWrapper.ts b/src/chatWrappers/FalconChatWrapper.ts index 36096ccf..f2f6825e 100644 --- a/src/chatWrappers/FalconChatWrapper.ts +++ b/src/chatWrappers/FalconChatWrapper.ts @@ -155,7 +155,7 @@ export class FalconChatWrapper extends ChatWrapper { /** @internal */ public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { - return [{ + return [{}, { allowSpecialTokensInTitles: true }] satisfies Partial[0]>[]; } diff --git a/src/chatWrappers/GeneralChatWrapper.ts b/src/chatWrappers/GeneralChatWrapper.ts index 7f1e1daa..7b38fb7b 100644 --- a/src/chatWrappers/GeneralChatWrapper.ts +++ b/src/chatWrappers/GeneralChatWrapper.ts @@ -174,7 +174,7 @@ export class GeneralChatWrapper extends ChatWrapper { /** @internal */ public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { - return [{ + return [{}, { allowSpecialTokensInTitles: true }] satisfies Partial[0]>[]; } diff --git a/src/chatWrappers/LlamaChatWrapper.ts b/src/chatWrappers/LlamaChatWrapper.ts index f40a631e..eed5f34d 100644 --- a/src/chatWrappers/LlamaChatWrapper.ts +++ b/src/chatWrappers/LlamaChatWrapper.ts @@ -8,16 +8,17 @@ export class LlamaChatWrapper extends ChatWrapper { /** @internal */ private readonly _addSpaceBeforeEos: boolean; - /** @internal */ public constructor({ - _addSpaceBeforeEos = true + addSpaceBeforeEos = true }: { - /** @internal */ - _addSpaceBeforeEos?: boolean + /** + * Default to `true` + */ + addSpaceBeforeEos?: boolean } = {}) { super(); - this._addSpaceBeforeEos = _addSpaceBeforeEos; + this._addSpaceBeforeEos = addSpaceBeforeEos; } public override generateContextText(history: readonly ChatHistoryItem[], {availableFunctions, documentFunctionParams}: { @@ -120,7 +121,7 @@ export class LlamaChatWrapper extends ChatWrapper { /** @internal */ public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { return [{}, { - _addSpaceBeforeEos: false + addSpaceBeforeEos: false }] satisfies Partial[0]>[]; } } diff --git a/src/chatWrappers/utils/resolveChatWrapper.ts b/src/chatWrappers/utils/resolveChatWrapper.ts index 6502f648..2fc76ebc 100644 --- a/src/chatWrappers/utils/resolveChatWrapper.ts +++ b/src/chatWrappers/utils/resolveChatWrapper.ts @@ -92,11 +92,17 @@ export function resolveChatWrapper({ warningLogs = true, fallbackToOtherWrappersOnJinjaError = true }: ResolveChatWrapperOptions) { - function createSpecializedChatWrapper(specializedChatWrapper: typeof chatWrappers[SpecializedChatWrapperTypeName]) { + function createSpecializedChatWrapper( + specializedChatWrapper: T, + defaultSettings: ConstructorParameters[0] = {} + ) { const chatWrapperConfigType = chatWrapperToConfigType.get(specializedChatWrapper) as SpecializedChatWrapperTypeName; const chatWrapperSettings = customWrapperSettings?.[chatWrapperConfigType]; - return new (specializedChatWrapper as any)(chatWrapperSettings); + return new (specializedChatWrapper as any)({ + ...(defaultSettings ?? {}), + ...(chatWrapperSettings ?? {}) + }); } if (type !== "auto" && type != null) { @@ -159,13 +165,14 @@ export function resolveChatWrapper({ testOptionConfigurations.push({} as any); for (const testConfiguration of testOptionConfigurations) { - const chatWrapper = new (Wrapper as any)({ + const testChatWrapperSettings = { ...(wrapperSettings ?? {}), ...(testConfiguration ?? {}) - }); + }; + const chatWrapper = new (Wrapper as any)(testChatWrapperSettings); if (isJinjaTemplateEquivalentToSpecializedChatWrapper(jinjaTemplateChatWrapperOptions, chatWrapper, tokenizer)) - return new (Wrapper as any)(wrapperSettings ?? {}); + return new (Wrapper as any)(testChatWrapperSettings); } } @@ -179,6 +186,19 @@ export function resolveChatWrapper({ } } + // try to find a pattern in the Jinja template to resolve to a specialized chat wrapper, + // with a logic similar to `llama.cpp`'s `llama_chat_apply_template_internal` function + if (modelJinjaTemplate != null && modelJinjaTemplate.trim() !== "") { + if (modelJinjaTemplate.includes("<|im_start|>")) + return createSpecializedChatWrapper(ChatMLChatWrapper); + else if (modelJinjaTemplate.includes("[INST]")) + return createSpecializedChatWrapper(LlamaChatWrapper, { + addSpaceBeforeEos: modelJinjaTemplate.includes("' ' + eos_token") + }); + else if (modelJinjaTemplate.includes("")) + return createSpecializedChatWrapper(GemmaChatWrapper); + } + if (filename != null) { const {name, subType, fileType} = parseModelFileName(filename); From a8c677ac347d9012a947490c4adc38c09ee764e3 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 2 Apr 2024 21:05:33 +0300 Subject: [PATCH 35/52] refactor: rename `SpecialToken` to `SpecialTokensText` and separate `BuiltinSpecialToken` inheritance from `SpecialToken` --- .vitepress/config.ts | 2 +- src/chatWrappers/ChatMLChatWrapper.ts | 16 +-- src/chatWrappers/FalconChatWrapper.ts | 20 ++-- src/chatWrappers/FunctionaryChatWrapper.ts | 90 +++++++------- src/chatWrappers/GemmaChatWrapper.ts | 12 +- src/chatWrappers/GeneralChatWrapper.ts | 34 +++--- src/chatWrappers/LlamaChatWrapper.ts | 10 +- .../generic/JinjaTemplateChatWrapper.ts | 6 +- .../generic/TemplateChatWrapper.ts | 24 ++-- src/evaluator/LlamaChat/LlamaChat.ts | 2 +- src/index.ts | 7 +- src/utils/LlamaText.ts | 110 ++++++++++-------- src/utils/StopGenerationDetector.ts | 6 +- 13 files changed, 181 insertions(+), 158 deletions(-) diff --git a/.vitepress/config.ts b/.vitepress/config.ts index b6cff6ac..494b05f2 100644 --- a/.vitepress/config.ts +++ b/.vitepress/config.ts @@ -326,7 +326,7 @@ function orderClasses(sidebar: typeof typedocSidebar) { items: [] }; (classes.items as DefaultTheme.SidebarItem[]).push(LlamaTextGroup); - const LlamaTextGroupItemsOrder = ["SpecialToken", "BuiltinSpecialToken"]; + const LlamaTextGroupItemsOrder = ["SpecialTokensText", "BuiltinSpecialToken"]; groupItems( classes.items, diff --git a/src/chatWrappers/ChatMLChatWrapper.ts b/src/chatWrappers/ChatMLChatWrapper.ts index 18fb57ed..4af57816 100644 --- a/src/chatWrappers/ChatMLChatWrapper.ts +++ b/src/chatWrappers/ChatMLChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialToken} from "../utils/LlamaText.js"; +import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; // source: https://github.com/openai/openai-python/blob/120d225b91a8453e15240a49fb1c6794d8119326/chatml.md export class ChatMLChatWrapper extends ChatWrapper { @@ -78,28 +78,28 @@ export class ChatMLChatWrapper extends ChatWrapper { (system.length === 0) ? LlamaText([]) : LlamaText([ - new SpecialToken("<|im_start|>system\n"), + new SpecialTokensText("<|im_start|>system\n"), system, - new SpecialToken("<|im_end|>\n") + new SpecialTokensText("<|im_end|>\n") ]), (user.length === 0) ? LlamaText([]) : LlamaText([ - new SpecialToken("<|im_start|>user\n"), + new SpecialTokensText("<|im_start|>user\n"), user, - new SpecialToken("<|im_end|>\n") + new SpecialTokensText("<|im_end|>\n") ]), (model.length === 0 && !isLastItem) ? LlamaText([]) : LlamaText([ - new SpecialToken("<|im_start|>assistant\n"), + new SpecialTokensText("<|im_start|>assistant\n"), model, isLastItem ? LlamaText([]) - : new SpecialToken("<|im_end|>\n") + : new SpecialTokensText("<|im_end|>\n") ]) ]); }) @@ -109,7 +109,7 @@ export class ChatMLChatWrapper extends ChatWrapper { contextText, stopGenerationTriggers: [ LlamaText(new BuiltinSpecialToken("EOS")), - LlamaText(new SpecialToken("<|im_end|>")), + LlamaText(new SpecialTokensText("<|im_end|>")), LlamaText("<|im_end|>") ] }; diff --git a/src/chatWrappers/FalconChatWrapper.ts b/src/chatWrappers/FalconChatWrapper.ts index f2f6825e..9cef38ae 100644 --- a/src/chatWrappers/FalconChatWrapper.ts +++ b/src/chatWrappers/FalconChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {LlamaText, BuiltinSpecialToken, SpecialToken} from "../utils/LlamaText.js"; +import {LlamaText, BuiltinSpecialToken, SpecialTokensText} from "../utils/LlamaText.js"; export class FalconChatWrapper extends ChatWrapper { public readonly wrapperName: string = "Falcon"; @@ -105,27 +105,27 @@ export class FalconChatWrapper extends ChatWrapper { : LlamaText([ isFirstItem ? LlamaText([]) - : SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `${this._middleSystemMessageTitle}: `), + : SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, `${this._middleSystemMessageTitle}: `), system, - SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") + SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (user.length === 0) ? LlamaText([]) : LlamaText([ - SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `${this._userMessageTitle}: `), + SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, `${this._userMessageTitle}: `), user, - SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") + SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (model.length === 0 && !isLastItem) ? LlamaText([]) : LlamaText([ - SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `${this._modelResponseTitle}: `), + SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, `${this._modelResponseTitle}: `), model, isLastItem ? LlamaText([]) - : SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") + : SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]) ]); }) @@ -144,9 +144,9 @@ export class FalconChatWrapper extends ChatWrapper { !this._allowSpecialTokensInTitles ? [] : [ - LlamaText(new SpecialToken(`\n${this._userMessageTitle}:`)), - LlamaText(new SpecialToken(`\n${this._modelResponseTitle}:`)), - LlamaText(new SpecialToken(`\n${this._middleSystemMessageTitle}:`)) + LlamaText(new SpecialTokensText(`\n${this._userMessageTitle}:`)), + LlamaText(new SpecialTokensText(`\n${this._modelResponseTitle}:`)), + LlamaText(new SpecialTokensText(`\n${this._middleSystemMessageTitle}:`)) ] ) ] diff --git a/src/chatWrappers/FunctionaryChatWrapper.ts b/src/chatWrappers/FunctionaryChatWrapper.ts index 8c4e65cf..da334380 100644 --- a/src/chatWrappers/FunctionaryChatWrapper.ts +++ b/src/chatWrappers/FunctionaryChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions, isChatModelResponseFunctionCall} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialToken} from "../utils/LlamaText.js"; +import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; import {getTypeScriptTypeStringForGbnfJsonSchema} from "../utils/getTypeScriptTypeStringForGbnfJsonSchema.js"; // source: https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v2.txt @@ -53,20 +53,20 @@ export class FunctionaryChatWrapper extends ChatWrapper { return LlamaText([ isFirstItem ? LlamaText([]) - : new SpecialToken("\n"), - new SpecialToken("<|from|>system\n"), - new SpecialToken("<|recipient|>all\n"), - new SpecialToken("<|content|>"), + : new SpecialTokensText("\n"), + new SpecialTokensText("<|from|>system\n"), + new SpecialTokensText("<|recipient|>all\n"), + new SpecialTokensText("<|content|>"), item.text ]); } else if (item.type === "user") { return LlamaText([ isFirstItem ? LlamaText([]) - : new SpecialToken("\n"), - new SpecialToken("<|from|>user\n"), - new SpecialToken("<|recipient|>all\n"), - new SpecialToken("<|content|>"), + : new SpecialTokensText("\n"), + new SpecialTokensText("<|from|>user\n"), + new SpecialTokensText("<|recipient|>all\n"), + new SpecialTokensText("<|content|>"), item.text ]); } else if (item.type === "model") { @@ -74,10 +74,10 @@ export class FunctionaryChatWrapper extends ChatWrapper { return LlamaText([ isFirstItem ? LlamaText([]) - : new SpecialToken("\n"), - new SpecialToken("<|from|>assistant\n"), - new SpecialToken("<|recipient|>all\n"), - new SpecialToken("<|content|>") + : new SpecialTokensText("\n"), + new SpecialTokensText("<|from|>assistant\n"), + new SpecialTokensText("<|recipient|>all\n"), + new SpecialTokensText("<|content|>") ]); return LlamaText( @@ -89,14 +89,14 @@ export class FunctionaryChatWrapper extends ChatWrapper { return LlamaText([ (isFirstItem && isFirstResponse) ? LlamaText([]) - : new SpecialToken("\n"), - new SpecialToken("<|from|>assistant\n"), - new SpecialToken("<|recipient|>all\n"), - new SpecialToken("<|content|>"), + : new SpecialTokensText("\n"), + new SpecialTokensText("<|from|>assistant\n"), + new SpecialTokensText("<|recipient|>all\n"), + new SpecialTokensText("<|content|>"), response, (isLastResponse && isLastItem) ? "" - : new SpecialToken("<|stop|>") + : new SpecialTokensText("<|stop|>") ]); else if (isChatModelResponseFunctionCall(response)) { return LlamaText([ @@ -105,20 +105,20 @@ export class FunctionaryChatWrapper extends ChatWrapper { : LlamaText([ (isFirstItem && isFirstResponse) ? LlamaText([]) - : new SpecialToken("\n"), + : new SpecialTokensText("\n"), - new SpecialToken("<|from|>assistant\n"), - new SpecialToken("<|recipient|>"), response.name, new SpecialToken("\n"), - new SpecialToken("<|content|>"), + new SpecialTokensText("<|from|>assistant\n"), + new SpecialTokensText("<|recipient|>"), response.name, new SpecialTokensText("\n"), + new SpecialTokensText("<|content|>"), response.params === undefined ? "" : JSON.stringify(response.params), - new SpecialToken("<|stop|>"), + new SpecialTokensText("<|stop|>"), - new SpecialToken("\n"), - new SpecialToken("<|from|>"), response.name, new SpecialToken("\n"), - new SpecialToken("<|recipient|>all\n"), - new SpecialToken("<|content|>"), + new SpecialTokensText("\n"), + new SpecialTokensText("<|from|>"), response.name, new SpecialTokensText("\n"), + new SpecialTokensText("<|recipient|>all\n"), + new SpecialTokensText("<|content|>"), response.result === undefined ? "" // "void" : JSON.stringify(response.result) @@ -127,10 +127,10 @@ export class FunctionaryChatWrapper extends ChatWrapper { hasFunctions ? LlamaText([]) : LlamaText([ - new SpecialToken("\n"), - new SpecialToken("<|from|>assistant\n"), - new SpecialToken("<|recipient|>all\n"), - new SpecialToken("<|content|>") + new SpecialTokensText("\n"), + new SpecialTokensText("<|from|>assistant\n"), + new SpecialTokensText("<|recipient|>all\n"), + new SpecialTokensText("<|content|>") ]) ]); } @@ -151,18 +151,18 @@ export class FunctionaryChatWrapper extends ChatWrapper { contextText, stopGenerationTriggers: [ LlamaText(new BuiltinSpecialToken("EOS")), - LlamaText(new SpecialToken("<|stop|>")), + LlamaText(new SpecialTokensText("<|stop|>")), LlamaText(" <|stop|>"), LlamaText("<|stop|>"), LlamaText("\n<|from|>user"), LlamaText("\n<|from|>assistant"), LlamaText("\n<|from|>system"), - LlamaText(new SpecialToken(" <|stop|>")), - LlamaText(new SpecialToken("<|stop|>")), - LlamaText(new SpecialToken("\n<|from|>user")), - LlamaText(new SpecialToken("\n<|from|>assistant")), - LlamaText(new SpecialToken("\n<|from|>system")) + LlamaText(new SpecialTokensText(" <|stop|>")), + LlamaText(new SpecialTokensText("<|stop|>")), + LlamaText(new SpecialTokensText("\n<|from|>user")), + LlamaText(new SpecialTokensText("\n<|from|>assistant")), + LlamaText(new SpecialTokensText("\n<|from|>system")) ] }; } @@ -171,29 +171,29 @@ export class FunctionaryChatWrapper extends ChatWrapper { contextText, stopGenerationTriggers: [ LlamaText(new BuiltinSpecialToken("EOS")), - LlamaText(new SpecialToken("<|stop|>")), + LlamaText(new SpecialTokensText("<|stop|>")), LlamaText(" <|stop|>"), LlamaText("<|stop|>"), LlamaText("\n<|from|>user"), - LlamaText(new SpecialToken(" <|stop|>")), - LlamaText(new SpecialToken("<|stop|>")), - LlamaText(new SpecialToken("\n<|from|>user")) + LlamaText(new SpecialTokensText(" <|stop|>")), + LlamaText(new SpecialTokensText("<|stop|>")), + LlamaText(new SpecialTokensText("\n<|from|>user")) ], ignoreStartText: [ LlamaText("\n<|from|>assistant\n<|recipient|>all\n<|content|>"), - LlamaText(new SpecialToken("\n<|from|>assistant\n<|recipient|>all\n<|content|>")), + LlamaText(new SpecialTokensText("\n<|from|>assistant\n<|recipient|>all\n<|content|>")), LlamaText("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>"), - LlamaText(new SpecialToken("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>")) + LlamaText(new SpecialTokensText("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>")) ], functionCall: { initiallyEngaged: true, disengageInitiallyEngaged: [ LlamaText("\n<|from|>assistant\n<|recipient|>all\n<|content|>"), - LlamaText(new SpecialToken("\n<|from|>assistant\n<|recipient|>all\n<|content|>")), + LlamaText(new SpecialTokensText("\n<|from|>assistant\n<|recipient|>all\n<|content|>")), LlamaText("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>"), - LlamaText(new SpecialToken("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>")) + LlamaText(new SpecialTokensText("\n\n<|from|>assistant\n<|recipient|>all\n<|content|>")) ] } }; diff --git a/src/chatWrappers/GemmaChatWrapper.ts b/src/chatWrappers/GemmaChatWrapper.ts index 4aab2dc3..7d99bd06 100644 --- a/src/chatWrappers/GemmaChatWrapper.ts +++ b/src/chatWrappers/GemmaChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialToken} from "../utils/LlamaText.js"; +import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; // source: https://ai.google.dev/gemma/docs/formatting // source: https://www.promptingguide.ai/models/gemma @@ -87,20 +87,20 @@ export class GemmaChatWrapper extends ChatWrapper { (user.length === 0) ? LlamaText([]) : LlamaText([ - new SpecialToken("user\n"), + new SpecialTokensText("user\n"), user, - new SpecialToken("\n") + new SpecialTokensText("\n") ]), (model.length === 0 && !isLastItem) ? LlamaText([]) : LlamaText([ - new SpecialToken("model\n"), + new SpecialTokensText("model\n"), model, isLastItem ? LlamaText([]) - : new SpecialToken("\n") + : new SpecialTokensText("\n") ]) ]); }) @@ -110,7 +110,7 @@ export class GemmaChatWrapper extends ChatWrapper { contextText, stopGenerationTriggers: [ LlamaText(new BuiltinSpecialToken("EOS")), - LlamaText(new SpecialToken("\n")), + LlamaText(new SpecialTokensText("\n")), LlamaText("") ] }; diff --git a/src/chatWrappers/GeneralChatWrapper.ts b/src/chatWrappers/GeneralChatWrapper.ts index 7b38fb7b..7aef46a6 100644 --- a/src/chatWrappers/GeneralChatWrapper.ts +++ b/src/chatWrappers/GeneralChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialToken} from "../utils/LlamaText.js"; +import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; export class GeneralChatWrapper extends ChatWrapper { public readonly wrapperName: string = "General"; @@ -106,27 +106,27 @@ export class GeneralChatWrapper extends ChatWrapper { : LlamaText([ isFirstItem ? LlamaText([]) - : SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `### ${this._middleSystemMessageTitle}\n`), + : SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, `### ${this._middleSystemMessageTitle}\n`), system, - SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") + SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (user.length === 0) ? LlamaText([]) : LlamaText([ - SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `### ${this._userMessageTitle}\n`), + SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, `### ${this._userMessageTitle}\n`), user, - SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") + SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]), (model.length === 0 && !isLastItem) ? LlamaText([]) : LlamaText([ - SpecialToken.wrapIf(this._allowSpecialTokensInTitles, `### ${this._modelResponseTitle}\n`), + SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, `### ${this._modelResponseTitle}\n`), model, isLastItem ? LlamaText([]) - : SpecialToken.wrapIf(this._allowSpecialTokensInTitles, "\n\n") + : SpecialTokensText.wrapIf(this._allowSpecialTokensInTitles, "\n\n") ]) ]); }) @@ -136,7 +136,7 @@ export class GeneralChatWrapper extends ChatWrapper { contextText, stopGenerationTriggers: [ LlamaText(new BuiltinSpecialToken("EOS")), - LlamaText(new SpecialToken("")), + LlamaText(new SpecialTokensText("")), LlamaText(""), LlamaText(`### ${this._userMessageTitle}`), @@ -155,17 +155,17 @@ export class GeneralChatWrapper extends ChatWrapper { !this._allowSpecialTokensInTitles ? [] : [ - LlamaText(new SpecialToken(`### ${this._userMessageTitle}`)), - LlamaText(new SpecialToken(`\n### ${this._userMessageTitle}`)), - LlamaText(new SpecialToken(`\n\n### ${this._userMessageTitle}`)), + LlamaText(new SpecialTokensText(`### ${this._userMessageTitle}`)), + LlamaText(new SpecialTokensText(`\n### ${this._userMessageTitle}`)), + LlamaText(new SpecialTokensText(`\n\n### ${this._userMessageTitle}`)), - LlamaText(new SpecialToken(`### ${this._modelResponseTitle}`)), - LlamaText(new SpecialToken(`\n### ${this._modelResponseTitle}`)), - LlamaText(new SpecialToken(`\n\n### ${this._modelResponseTitle}`)), + LlamaText(new SpecialTokensText(`### ${this._modelResponseTitle}`)), + LlamaText(new SpecialTokensText(`\n### ${this._modelResponseTitle}`)), + LlamaText(new SpecialTokensText(`\n\n### ${this._modelResponseTitle}`)), - LlamaText(new SpecialToken(`### ${this._middleSystemMessageTitle}`)), - LlamaText(new SpecialToken(`\n### ${this._middleSystemMessageTitle}`)), - LlamaText(new SpecialToken(`\n\n### ${this._middleSystemMessageTitle}`)) + LlamaText(new SpecialTokensText(`### ${this._middleSystemMessageTitle}`)), + LlamaText(new SpecialTokensText(`\n### ${this._middleSystemMessageTitle}`)), + LlamaText(new SpecialTokensText(`\n\n### ${this._middleSystemMessageTitle}`)) ] ) ] diff --git a/src/chatWrappers/LlamaChatWrapper.ts b/src/chatWrappers/LlamaChatWrapper.ts index eed5f34d..55574477 100644 --- a/src/chatWrappers/LlamaChatWrapper.ts +++ b/src/chatWrappers/LlamaChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialToken} from "../utils/LlamaText.js"; +import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; // source: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 export class LlamaChatWrapper extends ChatWrapper { @@ -87,16 +87,16 @@ export class LlamaChatWrapper extends ChatWrapper { (system.length === 0 && user.length === 0) ? LlamaText([]) : LlamaText([ - new SpecialToken("[INST] "), + new SpecialTokensText("[INST] "), system.length === 0 ? LlamaText([]) : LlamaText([ - new SpecialToken("<>\n"), + new SpecialTokensText("<>\n"), system, - new SpecialToken("\n<>\n\n") + new SpecialTokensText("\n<>\n\n") ]), user, - new SpecialToken(" [/INST] ") + new SpecialTokensText(" [/INST] ") ]), model, this._addSpaceBeforeEos diff --git a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts index 0ffd805c..1902eb1f 100644 --- a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts +++ b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts @@ -1,7 +1,7 @@ import {Template} from "@huggingface/jinja"; import {splitText} from "lifecycle-utils"; import {ChatHistoryItem, ChatModelFunctions, ChatUserMessage} from "../../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialToken} from "../../utils/LlamaText.js"; +import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../../utils/LlamaText.js"; import {ChatWrapper, ChatWrapperSettings} from "../../ChatWrapper.js"; import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate} from "./utils/chatHistoryFunctionCallMessageTemplate.js"; @@ -287,7 +287,7 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { const contextText = LlamaText( splitJinjaParts.map((part) => { if (typeof part === "string") - return new SpecialToken(part); // things that are not message content can be tokenized with special tokens + return new SpecialTokensText(part); // things that are not message content can be tokenized with special tokens const message = idToContent.get(part.separator); @@ -314,7 +314,7 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { LlamaText( stopGenerationJinjaParts.map((part) => { if (typeof part === "string") - return new SpecialToken(part); + return new SpecialTokensText(part); const message = idToContent.get(part.separator); diff --git a/src/chatWrappers/generic/TemplateChatWrapper.ts b/src/chatWrappers/generic/TemplateChatWrapper.ts index 16223642..f63a87c0 100644 --- a/src/chatWrappers/generic/TemplateChatWrapper.ts +++ b/src/chatWrappers/generic/TemplateChatWrapper.ts @@ -1,5 +1,5 @@ import {ChatHistoryItem, ChatModelFunctions} from "../../types.js"; -import {BuiltinSpecialToken, LlamaText, LlamaTextValue, SpecialToken} from "../../utils/LlamaText.js"; +import {BuiltinSpecialToken, LlamaText, LlamaTextValue, SpecialTokensText} from "../../utils/LlamaText.js"; import {ChatWrapper, ChatWrapperSettings} from "../../ChatWrapper.js"; import {parseTextTemplate} from "../../utils/parseTextTemplate.js"; import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate} from "./utils/chatHistoryFunctionCallMessageTemplate.js"; @@ -155,9 +155,9 @@ export class TemplateChatWrapper extends ChatWrapper { const getHistoryItem = (role: "system" | "user" | "model", text: string, prefix?: string | null) => { const {roleNamePrefix, messagePrefix, messageSuffix} = this._parsedChatHistoryTemplate; return LlamaText([ - new SpecialToken((prefix ?? "") + roleNamePrefix + role + messagePrefix), + new SpecialTokensText((prefix ?? "") + roleNamePrefix + role + messagePrefix), text, - new SpecialToken(messageSuffix) + new SpecialTokensText(messageSuffix) ]); }; @@ -169,12 +169,14 @@ export class TemplateChatWrapper extends ChatWrapper { const res = LlamaText([ isFirstItem ? system.length === 0 - ? new SpecialToken((this._parsedChatTemplate.systemPromptPrefix ?? "") + this._parsedChatTemplate.historyPrefix) + ? new SpecialTokensText( + (this._parsedChatTemplate.systemPromptPrefix ?? "") + this._parsedChatTemplate.historyPrefix + ) : this._parsedChatTemplate.systemPromptPrefix != null ? LlamaText([ - new SpecialToken(this._parsedChatTemplate.systemPromptPrefix), + new SpecialTokensText(this._parsedChatTemplate.systemPromptPrefix), system, - new SpecialToken(this._parsedChatTemplate.historyPrefix) + new SpecialTokensText(this._parsedChatTemplate.historyPrefix) ]) : getHistoryItem("system", system, this._parsedChatTemplate.historyPrefix) : system.length === 0 @@ -191,21 +193,21 @@ export class TemplateChatWrapper extends ChatWrapper { : !isLastItem ? getHistoryItem("model", model) : LlamaText([ - new SpecialToken(this._parsedChatTemplate.completionPrefix), + new SpecialTokensText(this._parsedChatTemplate.completionPrefix), model ]) ]); return LlamaText( res.values.reduce((res, value) => { - if (value instanceof SpecialToken) { + if (value instanceof SpecialTokensText) { const lastItem = res[res.length - 1]; - if (lastItem == null || !(lastItem instanceof SpecialToken)) + if (lastItem == null || !(lastItem instanceof SpecialTokensText)) return res.concat([value]); return res.slice(0, -1).concat([ - new SpecialToken(lastItem.value + value.value) + new SpecialTokensText(lastItem.value + value.value) ]); } @@ -220,7 +222,7 @@ export class TemplateChatWrapper extends ChatWrapper { stopGenerationTriggers: [ LlamaText(new BuiltinSpecialToken("EOS")), LlamaText(this._parsedChatTemplate.completionSuffix), - LlamaText(new SpecialToken(this._parsedChatTemplate.completionSuffix)) + LlamaText(new SpecialTokensText(this._parsedChatTemplate.completionSuffix)) ] }; } diff --git a/src/evaluator/LlamaChat/LlamaChat.ts b/src/evaluator/LlamaChat/LlamaChat.ts index a6c6c9fe..67f005c3 100644 --- a/src/evaluator/LlamaChat/LlamaChat.ts +++ b/src/evaluator/LlamaChat/LlamaChat.ts @@ -1146,7 +1146,7 @@ async function getContextWindow({ return { history: compressedHistory, stopGenerationTriggers, - tokens: contextText .tokenize(model.tokenize), + tokens: contextText.tokenize(model.tokenize), newResolvedHistory: resolvedHistory, newHistoryCompressionMetadata: metadata, ignoreStartText: ignoreStartText ?? [], diff --git a/src/index.ts b/src/index.ts index 555db6ea..3412fb42 100644 --- a/src/index.ts +++ b/src/index.ts @@ -49,8 +49,8 @@ import { type ResolveChatWrapperOptions } from "./chatWrappers/utils/resolveChatWrapper.js"; import { - LlamaText, SpecialToken, BuiltinSpecialToken, isLlamaText, tokenizeText, type LlamaTextJSON, type LlamaTextJSONValue, - type LlamaTextSpecialTokenJSON + LlamaText, SpecialTokensText, BuiltinSpecialToken, isLlamaText, tokenizeText, type LlamaTextJSON, type LlamaTextJSONValue, + type LlamaTextSpecialTokensTextJSON, type LlamaTextSpecialTokenJSON } from "./utils/LlamaText.js"; import {appendUserMessageToChatHistory} from "./utils/appendUserMessageToChatHistory.js"; import {getModuleVersion} from "./utils/getModuleVersion.js"; @@ -152,12 +152,13 @@ export { templateChatWrapperTypeNames, type TemplateChatWrapperTypeName, LlamaText, - SpecialToken, + SpecialTokensText, BuiltinSpecialToken, isLlamaText, tokenizeText, type LlamaTextJSON, type LlamaTextJSONValue, + type LlamaTextSpecialTokensTextJSON, type LlamaTextSpecialTokenJSON, appendUserMessageToChatHistory, getModuleVersion, diff --git a/src/utils/LlamaText.ts b/src/utils/LlamaText.ts index 84c87ac4..070b23d9 100644 --- a/src/utils/LlamaText.ts +++ b/src/utils/LlamaText.ts @@ -28,15 +28,16 @@ export type LlamaText = { includes(value: LlamaText): boolean }; -export type LlamaTextValue = string | SpecialToken; +export type LlamaTextValue = string | SpecialTokensText | BuiltinSpecialToken; export type LlamaTextJSON = Array; -export type LlamaTextJSONValue = string | LlamaTextSpecialTokenJSON; -export type LlamaTextSpecialTokenJSON = {type: "specialToken", value: string, builtin?: true}; +export type LlamaTextJSONValue = string | LlamaTextSpecialTokensTextJSON | LlamaTextSpecialTokenJSON; +export type LlamaTextSpecialTokensTextJSON = {type: "specialTokensText", value: string}; +export type LlamaTextSpecialTokenJSON = {type: "specialToken", value: string}; export const LlamaText: LlamaTextClass = function LlamaText( - strings: TemplateStringsArray | string | string[] | SpecialToken | LlamaText | LlamaText[], - ...values: (SpecialToken | string | string[] | number | boolean | LlamaText | LlamaText[])[] + strings: TemplateStringsArray | string | string[] | SpecialTokensText | BuiltinSpecialToken | LlamaText | LlamaText[], + ...values: (SpecialTokensText | BuiltinSpecialToken | string | string[] | number | boolean | LlamaText | LlamaText[])[] ) { return createLlamaText(createHistoryFromStringsAndValues(strings, values)); } as LlamaTextClass; @@ -45,8 +46,10 @@ LlamaText.fromJSON = function fromJSON(json: LlamaTextJSON) { json.map((value) => { if (typeof value === "string") return value; - else if (SpecialToken.isSpecialTokenJSON(value)) - return SpecialToken.fromJSON(value); + else if (BuiltinSpecialToken.isSpecialTokenJSON(value)) + return BuiltinSpecialToken.fromJSON(value); + else if (SpecialTokensText.isSpecialTokensTextJSON(value)) + return SpecialTokensText.fromJSON(value); else { void (value satisfies never); throw new Error(`Unknown value type: ${value}`); @@ -69,7 +72,7 @@ LlamaText.compare = function compare(a: LlamaText, b: LlamaText) { return true; }; -export class SpecialToken { +export class SpecialTokensText { public readonly value: string; public constructor(value: string) { @@ -84,56 +87,68 @@ export class SpecialToken { return tokenizer(this.value, trimLeadingSpace ? "trimLeadingSpace" : true); } - public toJSON(): LlamaTextSpecialTokenJSON { + public toJSON(): LlamaTextSpecialTokensTextJSON { return { - type: "specialToken", + type: "specialTokensText", value: this.value }; } - public static fromJSON(json: LlamaTextSpecialTokenJSON): SpecialToken { - if (json.builtin) - return new BuiltinSpecialToken(json.value as BuiltinSpecialTokenValue); - else - return new SpecialToken(json.value); + public static fromJSON(json: LlamaTextSpecialTokensTextJSON): SpecialTokensText { + if (SpecialTokensText.isSpecialTokensTextJSON(json)) + return new SpecialTokensText(json.value); + + throw new Error(`Invalid JSON for SpecialTokensText: ${JSON.stringify(json)}`); } - public static isSpecialTokenJSON(value: LlamaTextJSONValue): value is LlamaTextSpecialTokenJSON { - return value != null && typeof value === "object" && value.type === "specialToken"; + public static isSpecialTokensTextJSON(value: LlamaTextJSONValue): value is LlamaTextSpecialTokensTextJSON { + return value != null && typeof value === "object" && value.type === "specialTokensText"; } /** - * Wraps the value with a `SpecialToken` only if `shouldWrap` is true + * Wraps the value with a `SpecialTokensText` only if `shouldWrap` is true */ - public static wrapIf(shouldWrap: boolean, value: string): SpecialToken | string { + public static wrapIf(shouldWrap: boolean, value: string): SpecialTokensText | string { if (shouldWrap) - return new SpecialToken(value); + return new SpecialTokensText(value); else return value; } } export type BuiltinSpecialTokenValue = "BOS" | "EOS" | "NL"; -export class BuiltinSpecialToken extends SpecialToken { - public override readonly value: BuiltinSpecialTokenValue; +export class BuiltinSpecialToken { + public readonly value: BuiltinSpecialTokenValue; public constructor(value: BuiltinSpecialTokenValue) { - super(value); - this.value = value; } - public override tokenize(tokenizer: Tokenizer): Token[] { + public toString() { + return this.value; + } + + public tokenize(tokenizer: Tokenizer): Token[] { return tokenizer(this.value, "builtin"); } - public override toJSON(): LlamaTextSpecialTokenJSON { + public toJSON(): LlamaTextSpecialTokenJSON { return { type: "specialToken", - value: this.value, - builtin: true + value: this.value }; } + + public static fromJSON(json: LlamaTextSpecialTokenJSON): BuiltinSpecialToken { + if (BuiltinSpecialToken.isSpecialTokenJSON(json)) + return new BuiltinSpecialToken(json.value as BuiltinSpecialTokenValue); + + throw new Error(`Invalid JSON for SpecialToken: ${JSON.stringify(json)}`); + } + + public static isSpecialTokenJSON(value: LlamaTextJSONValue): value is LlamaTextSpecialTokenJSON { + return value != null && typeof value === "object" && value.type === "specialToken"; + } } export function isLlamaText(value: unknown): value is LlamaText { @@ -177,7 +192,9 @@ const LlamaTextPrototypeFunctions: Partial = { toString(this: LlamaText) { return this.values .map((value) => { - if (value instanceof SpecialToken) + if (value instanceof BuiltinSpecialToken) + return value.toString(); + else if (value instanceof SpecialTokensText) return value.toString(); else return value; @@ -192,7 +209,7 @@ const LlamaTextPrototypeFunctions: Partial = { if (value instanceof BuiltinSpecialToken) { res.push(...tokenizer(textToTokenize, false), ...value.tokenize(tokenizer)); textToTokenize = ""; - } else if (value instanceof SpecialToken) { + } else if (value instanceof SpecialTokensText) { res.push(...tokenizer(textToTokenize, false), ...value.tokenize(tokenizer, res.length > 0 || textToTokenize.length > 0)); textToTokenize = ""; } else @@ -205,8 +222,10 @@ const LlamaTextPrototypeFunctions: Partial = { }, toJSON(this: LlamaText) { return this.values.map((value) => { - if (value instanceof SpecialToken) - return {type: "specialToken", value: value.value} satisfies LlamaTextJSONValue; + if (value instanceof BuiltinSpecialToken) + return value.toJSON() satisfies LlamaTextJSONValue; + else if (value instanceof SpecialTokensText) + return value.toJSON() satisfies LlamaTextJSONValue; else return value satisfies LlamaTextJSONValue; }); @@ -223,13 +242,13 @@ const LlamaTextPrototypeFunctions: Partial = { if (firstValue instanceof BuiltinSpecialToken) break; - if (firstValue instanceof SpecialToken) { + if (firstValue instanceof SpecialTokensText) { const newValue = firstValue.value.trimStart(); if (newValue === "") { newValues.shift(); continue; } else if (newValue !== firstValue.value) { - newValues[0] = new SpecialToken(newValue); + newValues[0] = new SpecialTokensText(newValue); break; } @@ -260,13 +279,13 @@ const LlamaTextPrototypeFunctions: Partial = { if (lastValue instanceof BuiltinSpecialToken) break; - if (lastValue instanceof SpecialToken) { + if (lastValue instanceof SpecialTokensText) { const newValue = lastValue.value.trimEnd(); if (newValue === "") { newValues.pop(); continue; } else if (newValue !== lastValue.value) { - newValues[newValues.length - 1] = new SpecialToken(newValue); + newValues[newValues.length - 1] = new SpecialTokensText(newValue); break; } @@ -395,9 +414,9 @@ function createHistoryFromStringsAndValues[] ): Array { function addItemToRes(res: Array, item: LlamaTextInputValue) { - if (item === undefined || item === "" || (item instanceof SpecialToken && item.value === "")) + if (item === undefined || item === "" || (item instanceof SpecialTokensText && item.value === "")) return res; - else if (typeof item === "string" || item instanceof SpecialToken) + else if (typeof item === "string" || item instanceof SpecialTokensText || item instanceof BuiltinSpecialToken) return res.concat([item]); else if (isLlamaText(item)) return res.concat(item.values); @@ -406,7 +425,7 @@ function createHistoryFromStringsAndValues { if (isLlamaText(value)) return res.concat(value.values); - else if (value === "" || (value instanceof SpecialToken && value.value === "")) + else if (value === "" || (value instanceof SpecialTokensText && value.value === "")) return res; return res.concat([value]); @@ -434,8 +453,8 @@ function createHistoryFromStringsAndValues { .map((value) => { if (typeof value === "string") return [value]; - else if (value instanceof SpecialToken) + else if (value instanceof BuiltinSpecialToken) + return value.tokenize(tokenizer); + else if (value instanceof SpecialTokensText) return value.tokenize(tokenizer); return value satisfies never; From 128618b3aa12fb8508ae91ecb17fee155791e909 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 2 Apr 2024 21:15:04 +0300 Subject: [PATCH 36/52] refactor: rename `BuiltinSpecialToken` to `SpecialToken` --- .vitepress/config.ts | 2 +- src/chatWrappers/ChatMLChatWrapper.ts | 6 ++-- src/chatWrappers/FalconChatWrapper.ts | 6 ++-- src/chatWrappers/FunctionaryChatWrapper.ts | 8 ++--- src/chatWrappers/GemmaChatWrapper.ts | 4 +-- src/chatWrappers/GeneralChatWrapper.ts | 6 ++-- src/chatWrappers/LlamaChatWrapper.ts | 8 ++--- .../generic/JinjaTemplateChatWrapper.ts | 10 +++--- .../generic/TemplateChatWrapper.ts | 4 +-- ...plateEquivalentToSpecializedChatWrapper.ts | 4 +-- src/index.ts | 4 +-- src/utils/LlamaText.ts | 34 +++++++++---------- src/utils/StopGenerationDetector.ts | 4 +-- 13 files changed, 50 insertions(+), 50 deletions(-) diff --git a/.vitepress/config.ts b/.vitepress/config.ts index 494b05f2..cd11d87b 100644 --- a/.vitepress/config.ts +++ b/.vitepress/config.ts @@ -326,7 +326,7 @@ function orderClasses(sidebar: typeof typedocSidebar) { items: [] }; (classes.items as DefaultTheme.SidebarItem[]).push(LlamaTextGroup); - const LlamaTextGroupItemsOrder = ["SpecialTokensText", "BuiltinSpecialToken"]; + const LlamaTextGroupItemsOrder = ["SpecialTokensText", "SpecialToken"]; groupItems( classes.items, diff --git a/src/chatWrappers/ChatMLChatWrapper.ts b/src/chatWrappers/ChatMLChatWrapper.ts index 4af57816..0fdfe87e 100644 --- a/src/chatWrappers/ChatMLChatWrapper.ts +++ b/src/chatWrappers/ChatMLChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; +import {SpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; // source: https://github.com/openai/openai-python/blob/120d225b91a8453e15240a49fb1c6794d8119326/chatml.md export class ChatMLChatWrapper extends ChatWrapper { @@ -70,7 +70,7 @@ export class ChatMLChatWrapper extends ChatWrapper { flush(); const contextText = LlamaText( - new BuiltinSpecialToken("BOS"), + new SpecialToken("BOS"), resultItems.map(({system, user, model}, index) => { const isLastItem = index === resultItems.length - 1; @@ -108,7 +108,7 @@ export class ChatMLChatWrapper extends ChatWrapper { return { contextText, stopGenerationTriggers: [ - LlamaText(new BuiltinSpecialToken("EOS")), + LlamaText(new SpecialToken("EOS")), LlamaText(new SpecialTokensText("<|im_end|>")), LlamaText("<|im_end|>") ] diff --git a/src/chatWrappers/FalconChatWrapper.ts b/src/chatWrappers/FalconChatWrapper.ts index 9cef38ae..a1fe03ed 100644 --- a/src/chatWrappers/FalconChatWrapper.ts +++ b/src/chatWrappers/FalconChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {LlamaText, BuiltinSpecialToken, SpecialTokensText} from "../utils/LlamaText.js"; +import {LlamaText, SpecialToken, SpecialTokensText} from "../utils/LlamaText.js"; export class FalconChatWrapper extends ChatWrapper { public readonly wrapperName: string = "Falcon"; @@ -94,7 +94,7 @@ export class FalconChatWrapper extends ChatWrapper { flush(); const contextText = LlamaText( - new BuiltinSpecialToken("BOS"), + new SpecialToken("BOS"), resultItems.map(({system, user, model}, index) => { const isFirstItem = index === 0; const isLastItem = index === resultItems.length - 1; @@ -134,7 +134,7 @@ export class FalconChatWrapper extends ChatWrapper { return { contextText, stopGenerationTriggers: [ - LlamaText(new BuiltinSpecialToken("EOS")), + LlamaText(new SpecialToken("EOS")), LlamaText(`\n${this._userMessageTitle}:`), LlamaText(`\n${this._modelResponseTitle}:`), diff --git a/src/chatWrappers/FunctionaryChatWrapper.ts b/src/chatWrappers/FunctionaryChatWrapper.ts index da334380..129f29e0 100644 --- a/src/chatWrappers/FunctionaryChatWrapper.ts +++ b/src/chatWrappers/FunctionaryChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions, isChatModelResponseFunctionCall} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; +import {SpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; import {getTypeScriptTypeStringForGbnfJsonSchema} from "../utils/getTypeScriptTypeStringForGbnfJsonSchema.js"; // source: https://github.com/MeetKai/functionary/blob/main/tests/prompt_test_v2.txt @@ -41,7 +41,7 @@ export class FunctionaryChatWrapper extends ChatWrapper { }); const contextText = LlamaText( - new BuiltinSpecialToken("BOS"), + new SpecialToken("BOS"), historyWithFunctions.map((item, index) => { const isFirstItem = index === 0; const isLastItem = index === historyWithFunctions.length - 1; @@ -150,7 +150,7 @@ export class FunctionaryChatWrapper extends ChatWrapper { return { contextText, stopGenerationTriggers: [ - LlamaText(new BuiltinSpecialToken("EOS")), + LlamaText(new SpecialToken("EOS")), LlamaText(new SpecialTokensText("<|stop|>")), LlamaText(" <|stop|>"), LlamaText("<|stop|>"), @@ -170,7 +170,7 @@ export class FunctionaryChatWrapper extends ChatWrapper { return { contextText, stopGenerationTriggers: [ - LlamaText(new BuiltinSpecialToken("EOS")), + LlamaText(new SpecialToken("EOS")), LlamaText(new SpecialTokensText("<|stop|>")), LlamaText(" <|stop|>"), diff --git a/src/chatWrappers/GemmaChatWrapper.ts b/src/chatWrappers/GemmaChatWrapper.ts index 7d99bd06..edff479e 100644 --- a/src/chatWrappers/GemmaChatWrapper.ts +++ b/src/chatWrappers/GemmaChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; +import {SpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; // source: https://ai.google.dev/gemma/docs/formatting // source: https://www.promptingguide.ai/models/gemma @@ -109,7 +109,7 @@ export class GemmaChatWrapper extends ChatWrapper { return { contextText, stopGenerationTriggers: [ - LlamaText(new BuiltinSpecialToken("EOS")), + LlamaText(new SpecialToken("EOS")), LlamaText(new SpecialTokensText("\n")), LlamaText("") ] diff --git a/src/chatWrappers/GeneralChatWrapper.ts b/src/chatWrappers/GeneralChatWrapper.ts index 7aef46a6..3c591951 100644 --- a/src/chatWrappers/GeneralChatWrapper.ts +++ b/src/chatWrappers/GeneralChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; +import {SpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; export class GeneralChatWrapper extends ChatWrapper { public readonly wrapperName: string = "General"; @@ -95,7 +95,7 @@ export class GeneralChatWrapper extends ChatWrapper { flush(); const contextText = LlamaText( - new BuiltinSpecialToken("BOS"), + new SpecialToken("BOS"), resultItems.map(({system, user, model}, index) => { const isFirstItem = index === 0; const isLastItem = index === resultItems.length - 1; @@ -135,7 +135,7 @@ export class GeneralChatWrapper extends ChatWrapper { return { contextText, stopGenerationTriggers: [ - LlamaText(new BuiltinSpecialToken("EOS")), + LlamaText(new SpecialToken("EOS")), LlamaText(new SpecialTokensText("")), LlamaText(""), diff --git a/src/chatWrappers/LlamaChatWrapper.ts b/src/chatWrappers/LlamaChatWrapper.ts index 55574477..71b17b5c 100644 --- a/src/chatWrappers/LlamaChatWrapper.ts +++ b/src/chatWrappers/LlamaChatWrapper.ts @@ -1,6 +1,6 @@ import {ChatWrapper} from "../ChatWrapper.js"; import {ChatHistoryItem, ChatModelFunctions} from "../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; +import {SpecialToken, LlamaText, SpecialTokensText} from "../utils/LlamaText.js"; // source: https://huggingface.co/blog/llama2#how-to-prompt-llama-2 export class LlamaChatWrapper extends ChatWrapper { @@ -83,7 +83,7 @@ export class LlamaChatWrapper extends ChatWrapper { const isLastItem = index === resultItems.length - 1; return LlamaText([ - new BuiltinSpecialToken("BOS"), + new SpecialToken("BOS"), (system.length === 0 && user.length === 0) ? LlamaText([]) : LlamaText([ @@ -104,7 +104,7 @@ export class LlamaChatWrapper extends ChatWrapper { : "", isLastItem ? LlamaText([]) - : new BuiltinSpecialToken("EOS") + : new SpecialToken("EOS") ]); }) ); @@ -112,7 +112,7 @@ export class LlamaChatWrapper extends ChatWrapper { return { contextText, stopGenerationTriggers: [ - LlamaText(new BuiltinSpecialToken("EOS")), + LlamaText(new SpecialToken("EOS")), LlamaText("") ] }; diff --git a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts index 1902eb1f..e6313ac8 100644 --- a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts +++ b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts @@ -1,7 +1,7 @@ import {Template} from "@huggingface/jinja"; import {splitText} from "lifecycle-utils"; import {ChatHistoryItem, ChatModelFunctions, ChatUserMessage} from "../../types.js"; -import {BuiltinSpecialToken, LlamaText, SpecialTokensText} from "../../utils/LlamaText.js"; +import {SpecialToken, LlamaText, SpecialTokensText} from "../../utils/LlamaText.js"; import {ChatWrapper, ChatWrapperSettings} from "../../ChatWrapper.js"; import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate} from "./utils/chatHistoryFunctionCallMessageTemplate.js"; @@ -188,7 +188,7 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { user: this.userRoleName, model: this.modelRoleName } as const; - const idToContent = new Map(); + const idToContent = new Map(); const modelMessageIds = new Set(); const messageIds = new Set(); @@ -209,8 +209,8 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { const bosTokenId = idsGenerator.generateId(); const eosTokenId = idsGenerator.generateId(); - idToContent.set(bosTokenId, new BuiltinSpecialToken("BOS")); - idToContent.set(eosTokenId, new BuiltinSpecialToken("EOS")); + idToContent.set(bosTokenId, new SpecialToken("BOS")); + idToContent.set(eosTokenId, new SpecialToken("EOS")); const renderJinjaText = () => { try { @@ -306,7 +306,7 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { return { contextText, stopGenerationTriggers: [ - LlamaText(new BuiltinSpecialToken("EOS")), + LlamaText(new SpecialToken("EOS")), ...( stopGenerationJinjaParts.length === 0 ? [] diff --git a/src/chatWrappers/generic/TemplateChatWrapper.ts b/src/chatWrappers/generic/TemplateChatWrapper.ts index f63a87c0..fa6e6996 100644 --- a/src/chatWrappers/generic/TemplateChatWrapper.ts +++ b/src/chatWrappers/generic/TemplateChatWrapper.ts @@ -1,5 +1,5 @@ import {ChatHistoryItem, ChatModelFunctions} from "../../types.js"; -import {BuiltinSpecialToken, LlamaText, LlamaTextValue, SpecialTokensText} from "../../utils/LlamaText.js"; +import {SpecialToken, LlamaText, LlamaTextValue, SpecialTokensText} from "../../utils/LlamaText.js"; import {ChatWrapper, ChatWrapperSettings} from "../../ChatWrapper.js"; import {parseTextTemplate} from "../../utils/parseTextTemplate.js"; import {ChatHistoryFunctionCallMessageTemplate, parseFunctionCallMessageTemplate} from "./utils/chatHistoryFunctionCallMessageTemplate.js"; @@ -220,7 +220,7 @@ export class TemplateChatWrapper extends ChatWrapper { return { contextText, stopGenerationTriggers: [ - LlamaText(new BuiltinSpecialToken("EOS")), + LlamaText(new SpecialToken("EOS")), LlamaText(this._parsedChatTemplate.completionSuffix), LlamaText(new SpecialTokensText(this._parsedChatTemplate.completionSuffix)) ] diff --git a/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts b/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts index 6e02829a..824fb97b 100644 --- a/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts +++ b/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts @@ -1,7 +1,7 @@ import {ChatWrapper} from "../../ChatWrapper.js"; import {ChatHistoryItem, ChatModelResponse, ChatUserMessage, Tokenizer} from "../../types.js"; import {JinjaTemplateChatWrapper, JinjaTemplateChatWrapperOptions} from "../generic/JinjaTemplateChatWrapper.js"; -import {BuiltinSpecialToken, LlamaText} from "../../utils/LlamaText.js"; +import {SpecialToken, LlamaText} from "../../utils/LlamaText.js"; import {compareTokens} from "../../utils/compareTokens.js"; import {StopGenerationDetector} from "../../utils/StopGenerationDetector.js"; @@ -242,7 +242,7 @@ function removeLeadingBos(llamaText: LlamaText) { const firstValue = llamaText.values[0]; - if (firstValue instanceof BuiltinSpecialToken && firstValue.value === "BOS") + if (firstValue instanceof SpecialToken && firstValue.value === "BOS") return LlamaText(llamaText.values.slice(1)); return llamaText; diff --git a/src/index.ts b/src/index.ts index 3412fb42..889d9e35 100644 --- a/src/index.ts +++ b/src/index.ts @@ -49,7 +49,7 @@ import { type ResolveChatWrapperOptions } from "./chatWrappers/utils/resolveChatWrapper.js"; import { - LlamaText, SpecialTokensText, BuiltinSpecialToken, isLlamaText, tokenizeText, type LlamaTextJSON, type LlamaTextJSONValue, + LlamaText, SpecialTokensText, SpecialToken, isLlamaText, tokenizeText, type LlamaTextJSON, type LlamaTextJSONValue, type LlamaTextSpecialTokensTextJSON, type LlamaTextSpecialTokenJSON } from "./utils/LlamaText.js"; import {appendUserMessageToChatHistory} from "./utils/appendUserMessageToChatHistory.js"; @@ -153,7 +153,7 @@ export { type TemplateChatWrapperTypeName, LlamaText, SpecialTokensText, - BuiltinSpecialToken, + SpecialToken, isLlamaText, tokenizeText, type LlamaTextJSON, diff --git a/src/utils/LlamaText.ts b/src/utils/LlamaText.ts index 070b23d9..ccaffd39 100644 --- a/src/utils/LlamaText.ts +++ b/src/utils/LlamaText.ts @@ -28,7 +28,7 @@ export type LlamaText = { includes(value: LlamaText): boolean }; -export type LlamaTextValue = string | SpecialTokensText | BuiltinSpecialToken; +export type LlamaTextValue = string | SpecialTokensText | SpecialToken; export type LlamaTextJSON = Array; export type LlamaTextJSONValue = string | LlamaTextSpecialTokensTextJSON | LlamaTextSpecialTokenJSON; @@ -36,8 +36,8 @@ export type LlamaTextSpecialTokensTextJSON = {type: "specialTokensText", value: export type LlamaTextSpecialTokenJSON = {type: "specialToken", value: string}; export const LlamaText: LlamaTextClass = function LlamaText( - strings: TemplateStringsArray | string | string[] | SpecialTokensText | BuiltinSpecialToken | LlamaText | LlamaText[], - ...values: (SpecialTokensText | BuiltinSpecialToken | string | string[] | number | boolean | LlamaText | LlamaText[])[] + strings: TemplateStringsArray | string | string[] | SpecialTokensText | SpecialToken | LlamaText | LlamaText[], + ...values: (SpecialTokensText | SpecialToken | string | string[] | number | boolean | LlamaText | LlamaText[])[] ) { return createLlamaText(createHistoryFromStringsAndValues(strings, values)); } as LlamaTextClass; @@ -46,8 +46,8 @@ LlamaText.fromJSON = function fromJSON(json: LlamaTextJSON) { json.map((value) => { if (typeof value === "string") return value; - else if (BuiltinSpecialToken.isSpecialTokenJSON(value)) - return BuiltinSpecialToken.fromJSON(value); + else if (SpecialToken.isSpecialTokenJSON(value)) + return SpecialToken.fromJSON(value); else if (SpecialTokensText.isSpecialTokensTextJSON(value)) return SpecialTokensText.fromJSON(value); else { @@ -117,7 +117,7 @@ export class SpecialTokensText { } export type BuiltinSpecialTokenValue = "BOS" | "EOS" | "NL"; -export class BuiltinSpecialToken { +export class SpecialToken { public readonly value: BuiltinSpecialTokenValue; public constructor(value: BuiltinSpecialTokenValue) { @@ -139,9 +139,9 @@ export class BuiltinSpecialToken { }; } - public static fromJSON(json: LlamaTextSpecialTokenJSON): BuiltinSpecialToken { - if (BuiltinSpecialToken.isSpecialTokenJSON(json)) - return new BuiltinSpecialToken(json.value as BuiltinSpecialTokenValue); + public static fromJSON(json: LlamaTextSpecialTokenJSON): SpecialToken { + if (SpecialToken.isSpecialTokenJSON(json)) + return new SpecialToken(json.value as BuiltinSpecialTokenValue); throw new Error(`Invalid JSON for SpecialToken: ${JSON.stringify(json)}`); } @@ -192,7 +192,7 @@ const LlamaTextPrototypeFunctions: Partial = { toString(this: LlamaText) { return this.values .map((value) => { - if (value instanceof BuiltinSpecialToken) + if (value instanceof SpecialToken) return value.toString(); else if (value instanceof SpecialTokensText) return value.toString(); @@ -206,7 +206,7 @@ const LlamaTextPrototypeFunctions: Partial = { const res: Token[] = []; for (const value of this.values) { - if (value instanceof BuiltinSpecialToken) { + if (value instanceof SpecialToken) { res.push(...tokenizer(textToTokenize, false), ...value.tokenize(tokenizer)); textToTokenize = ""; } else if (value instanceof SpecialTokensText) { @@ -222,7 +222,7 @@ const LlamaTextPrototypeFunctions: Partial = { }, toJSON(this: LlamaText) { return this.values.map((value) => { - if (value instanceof BuiltinSpecialToken) + if (value instanceof SpecialToken) return value.toJSON() satisfies LlamaTextJSONValue; else if (value instanceof SpecialTokensText) return value.toJSON() satisfies LlamaTextJSONValue; @@ -239,7 +239,7 @@ const LlamaTextPrototypeFunctions: Partial = { while (newValues.length > 0) { const firstValue = newValues[0]; - if (firstValue instanceof BuiltinSpecialToken) + if (firstValue instanceof SpecialToken) break; if (firstValue instanceof SpecialTokensText) { @@ -276,7 +276,7 @@ const LlamaTextPrototypeFunctions: Partial = { while (newValues.length > 0) { const lastValue = newValues[newValues.length - 1]; - if (lastValue instanceof BuiltinSpecialToken) + if (lastValue instanceof SpecialToken) break; if (lastValue instanceof SpecialTokensText) { @@ -416,7 +416,7 @@ function createHistoryFromStringsAndValues, item: LlamaTextInputValue) { if (item === undefined || item === "" || (item instanceof SpecialTokensText && item.value === "")) return res; - else if (typeof item === "string" || item instanceof SpecialTokensText || item instanceof BuiltinSpecialToken) + else if (typeof item === "string" || item instanceof SpecialTokensText || item instanceof SpecialToken) return res.concat([item]); else if (isLlamaText(item)) return res.concat(item.values); @@ -445,7 +445,7 @@ function createHistoryFromStringsAndValues { .map((value) => { if (typeof value === "string") return [value]; - else if (value instanceof BuiltinSpecialToken) + else if (value instanceof SpecialToken) return value.tokenize(tokenizer); else if (value instanceof SpecialTokensText) return value.tokenize(tokenizer); From d754b4bc7c01b199c58955ea6136b5ccc35825fe Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 2 Apr 2024 22:40:49 +0300 Subject: [PATCH 37/52] feat: improve control of leading space in tokenization, trim leading whitespace in responses when using `JinjaTemplateChatWrapper` by default --- .../generic/JinjaTemplateChatWrapper.ts | 53 +++++++++++- ...plateEquivalentToSpecializedChatWrapper.ts | 33 ++++++- src/cli/commands/ChatCommand.ts | 74 +++------------- src/cli/commands/CompleteCommand.ts | 62 +++----------- src/cli/commands/InfillCommand.ts | 62 +++----------- src/cli/utils/printCommonInfoLines.ts | 85 +++++++++++++++++++ src/evaluator/LlamaChat/LlamaChat.ts | 8 +- src/evaluator/LlamaModel.ts | 45 ++++++---- src/types.ts | 2 +- src/utils/LlamaText.ts | 2 +- src/utils/StopGenerationDetector.ts | 2 +- 11 files changed, 237 insertions(+), 191 deletions(-) create mode 100644 src/cli/utils/printCommonInfoLines.ts diff --git a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts index e6313ac8..a389aa2b 100644 --- a/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts +++ b/src/chatWrappers/generic/JinjaTemplateChatWrapper.ts @@ -10,9 +10,39 @@ export type JinjaTemplateChatWrapperOptions = { modelRoleName?: string, userRoleName?: string, systemRoleName?: string, - convertUnsupportedSystemMessagesToUserMessages?: boolean | "auto" | ConvertMessageFormatOptions, + + /** + * Some Jinja templates may not support system messages, and in such cases, + * it'll be detected and system messages can be converted to user messages. + * + * You can specify the format of the converted user message. + * - **"auto"**: Convert system messages to user messages only if the template does not support system messages. + * - **`true`**: Always convert system messages to user messages. + * - **`false`**: Never convert system messages to user messages. + * May throw an error if some system messages don't appear in the template. + * - **`{use: "ifNeeded", format: "..."}`**: Convert system messages to user messages only if the template does not support system + * messages with the specified format. + * - **`{use: "always", format: "..."}`**: Always convert system messages to user messages with the specified format. + * + * Defaults to `"auto"`. + */ + convertUnsupportedSystemMessagesToUserMessages?: "auto" | boolean | ConvertMessageFormatOptions, functionCallMessageTemplate?: ChatHistoryFunctionCallMessageTemplate, - joinAdjacentMessagesOfTheSameType?: boolean + + /** + * Whether to join adjacent messages of the same type. + * Some Jinja templates may throw an error if this is not set to `true`. + * + * Defaults to `true`. + */ + joinAdjacentMessagesOfTheSameType?: boolean, + + /** + * Whether to trim leading whitespace in responses. + * + * Defaults to `true`. + */ + trimLeadingWhitespaceInResponses?: boolean }; type ConvertMessageFormatOptions = { @@ -43,6 +73,7 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { public readonly systemRoleName: string; public readonly convertUnsupportedSystemMessagesToUserMessages?: ConvertMessageFormatOptions; public readonly joinAdjacentMessagesOfTheSameType: boolean; + public readonly trimLeadingWhitespaceInResponses: boolean; /** @internal */ private readonly _jinjaTemplate: Template; @@ -53,7 +84,8 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { systemRoleName = "system", convertUnsupportedSystemMessagesToUserMessages = defaultConvertUnsupportedSystemMessagesToUserMessagesFormat, functionCallMessageTemplate, - joinAdjacentMessagesOfTheSameType = true + joinAdjacentMessagesOfTheSameType = true, + trimLeadingWhitespaceInResponses = true }: JinjaTemplateChatWrapperOptions) { super(); @@ -67,6 +99,7 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { this.convertUnsupportedSystemMessagesToUserMessages = resolveConvertUnsupportedSystemMessagesToUserMessagesOption(convertUnsupportedSystemMessagesToUserMessages); this.joinAdjacentMessagesOfTheSameType = joinAdjacentMessagesOfTheSameType; + this.trimLeadingWhitespaceInResponses = trimLeadingWhitespaceInResponses; this.settings = { ...super.settings, @@ -119,7 +152,8 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { convertSystemMessagesToUserMessagesFormat?: string }): { contextText: LlamaText, - stopGenerationTriggers: LlamaText[] + stopGenerationTriggers: LlamaText[], + ignoreStartText?: LlamaText[] } { const transformedHistory = convertSystemMessagesToUserMessagesFormat == null ? history @@ -305,6 +339,17 @@ export class JinjaTemplateChatWrapper extends ChatWrapper { return { contextText, + ignoreStartText: !this.trimLeadingWhitespaceInResponses + ? [] + : [ + // ignore up to 4 leading spaces + ...Array(4).fill(0) + .map((_, index) => LlamaText(" ".repeat(index + 1))), + LlamaText("\t"), + LlamaText("\t\t"), + LlamaText("\t "), + LlamaText(" \t") + ], stopGenerationTriggers: [ LlamaText(new SpecialToken("EOS")), ...( diff --git a/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts b/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts index 824fb97b..34fc0675 100644 --- a/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts +++ b/src/chatWrappers/utils/isJinjaTemplateEquivalentToSpecializedChatWrapper.ts @@ -19,11 +19,24 @@ export function isJinjaTemplateEquivalentToSpecializedChatWrapper( ...jinjaTemplateWrapperOptions, convertUnsupportedSystemMessagesToUserMessages: canTestMultipleConvertSystemMessagesToUserMessages ? false - : jinjaTemplateWrapperOptions.convertUnsupportedSystemMessagesToUserMessages + : jinjaTemplateWrapperOptions.convertUnsupportedSystemMessagesToUserMessages, + trimLeadingWhitespaceInResponses: false }); if (checkEquivalence(jinjaChatWrapper, specializedChatWrapper, testChatHistories, tokenizer)) return true; + + + const jinjaChatWrapperWithLeadingWhitespaceTrimming = new JinjaTemplateChatWrapper({ + ...jinjaTemplateWrapperOptions, + convertUnsupportedSystemMessagesToUserMessages: canTestMultipleConvertSystemMessagesToUserMessages + ? false + : jinjaTemplateWrapperOptions.convertUnsupportedSystemMessagesToUserMessages, + trimLeadingWhitespaceInResponses: true + }); + + if (checkEquivalence(jinjaChatWrapperWithLeadingWhitespaceTrimming, specializedChatWrapper, testChatHistories, tokenizer)) + return true; } catch (err) { // Do nothing } @@ -38,7 +51,8 @@ export function isJinjaTemplateEquivalentToSpecializedChatWrapper( convertUnsupportedSystemMessagesToUserMessages: { use: "always", format: convertSystemMessagesToUserMessagesTemplate - } + }, + trimLeadingWhitespaceInResponses: false }); const transformedTestChatHistories = testChatHistories @@ -68,6 +82,21 @@ export function isJinjaTemplateEquivalentToSpecializedChatWrapper( if (checkEquivalence(jinjaChatWrapper, specializedChatWrapper, transformedTestChatHistories, tokenizer)) return true; + + + const jinjaChatWrapperWithLeadingWhitespaceTrimming = new JinjaTemplateChatWrapper({ + ...jinjaTemplateWrapperOptions, + convertUnsupportedSystemMessagesToUserMessages: { + use: "always", + format: convertSystemMessagesToUserMessagesTemplate + }, + trimLeadingWhitespaceInResponses: true + }); + + if (checkEquivalence( + jinjaChatWrapperWithLeadingWhitespaceTrimming, specializedChatWrapper, transformedTestChatHistories, tokenizer + )) + return true; } catch (err) { // Do nothing } diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts index 5a789f6b..08a9b241 100644 --- a/src/cli/commands/ChatCommand.ts +++ b/src/cli/commands/ChatCommand.ts @@ -17,11 +17,11 @@ import {LlamaLogLevel, LlamaLogLevelGreaterThan} from "../../bindings/types.js"; import withOra from "../../utils/withOra.js"; import {TokenMeter} from "../../evaluator/TokenMeter.js"; import {printInfoLine} from "../utils/printInfoLine.js"; -import {getPrettyBuildGpuName} from "../../bindings/consts.js"; import { resolveChatWrapper, SpecializedChatWrapperTypeName, specializedChatWrapperTypeNames } from "../../chatWrappers/utils/resolveChatWrapper.js"; import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js"; +import {printCommonInfoLines} from "../utils/printCommonInfoLines.js"; type ChatCommand = { model: string, @@ -112,6 +112,7 @@ export const ChatCommand: CommandModule = { alias: "c", type: "number", description: "Context size to use for the model context", + default: -1, defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) @@ -173,6 +174,7 @@ export const ChatCommand: CommandModule = { alias: "gl", type: "number", description: "number of layers to store in VRAM", + default: -1, defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) @@ -281,6 +283,9 @@ async function RunChat({ lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings }: ChatCommand) { + if (contextSize === -1) contextSize = undefined; + if (gpuLayers === -1) gpuLayers = undefined; + if (debug) console.info(`${chalk.yellow("Log level:")} debug`); @@ -362,11 +367,9 @@ async function RunChat({ : grammarArg !== "text" ? await LlamaGrammar.getFor(llama, grammarArg) : undefined; - const bos = model.tokens.bosString; // bos = beginning of sequence - const eos = model.tokens.bosString; // eos = end of sequence const chatWrapper = resolveChatWrapper({ type: wrapper, - bosString: bos, + bosString: model.tokens.bosString, filename: model.filename, fileInfo: model.fileInfo, tokenizer: model.tokenize @@ -390,62 +393,13 @@ async function RunChat({ } const padTitle = "Context".length + 1; - if (llama.gpu !== false) { - printInfoLine({ - title: "GPU", - padTitle: padTitle, - info: [{ - title: "Type", - value: getPrettyBuildGpuName(llama.gpu) - }, { - title: "VRAM", - value: bytes(llama.getVramState().total) - }, { - title: "Name", - value: llama.getGpuDeviceNames().join(", ") - }, { - title: "GPU layers", - value: `${model.gpuLayers}/${model.fileInsights.totalLayers} offloaded ${ - chalk.dim(`(${Math.floor((model.gpuLayers / model.fileInsights.totalLayers) * 100)}%)`) - }` - }] - }); - } - printInfoLine({ - title: "Model", - padTitle: padTitle, - info: [{ - title: "Type", - value: model.typeDescription - }, { - title: "Size", - value: bytes(model.size) - }, { - title: "BOS", - value: String(bos) - }, { - title: "EOS", - value: String(eos) - }, { - title: "Train context size", - value: String(model.trainContextSize) - }] - }); - printInfoLine({ - title: "Context", - padTitle: padTitle, - info: [{ - title: "Size", - value: String(context.contextSize) - }, { - show: logBatchSize, - title: "Batch size", - value: bytes(context.batchSize) - }, { - show: meter, - title: "Token meter", - value: "enabled" - }] + printCommonInfoLines({ + context, + minTitleLength: padTitle, + printBos: true, + printEos: true, + logBatchSize, + tokenMeterEnabled: meter }); printInfoLine({ title: "Chat", diff --git a/src/cli/commands/CompleteCommand.ts b/src/cli/commands/CompleteCommand.ts index cddd50de..0d1ff92a 100644 --- a/src/cli/commands/CompleteCommand.ts +++ b/src/cli/commands/CompleteCommand.ts @@ -11,7 +11,7 @@ import {LlamaCompletion} from "../../evaluator/LlamaCompletion.js"; import withOra from "../../utils/withOra.js"; import {TokenMeter} from "../../evaluator/TokenMeter.js"; import {printInfoLine} from "../utils/printInfoLine.js"; -import {getPrettyBuildGpuName} from "../../bindings/consts.js"; +import {printCommonInfoLines} from "../utils/printCommonInfoLines.js"; type CompleteCommand = { model: string, @@ -70,6 +70,7 @@ export const CompleteCommand: CommandModule = { alias: "c", type: "number", description: "Context size to use for the model context", + default: -1, defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) @@ -117,6 +118,7 @@ export const CompleteCommand: CommandModule = { alias: "gl", type: "number", description: "number of layers to store in VRAM", + default: -1, defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) @@ -210,6 +212,9 @@ async function RunCompletion({ lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, debug, meter, printTimings }: CompleteCommand) { + if (contextSize === -1) contextSize = undefined; + if (gpuLayers === -1) gpuLayers = undefined; + if (debug) console.info(`${chalk.yellow("Log level:")} debug`); @@ -284,56 +289,11 @@ async function RunCompletion({ await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing const padTitle = "Complete".length + 1; - if (llama.gpu !== false) { - printInfoLine({ - title: "GPU", - padTitle: padTitle, - info: [{ - title: "Type", - value: getPrettyBuildGpuName(llama.gpu) - }, { - title: "VRAM", - value: bytes(llama.getVramState().total) - }, { - title: "Name", - value: llama.getGpuDeviceNames().join(", ") - }, { - title: "GPU layers", - value: `${model.gpuLayers}/${model.fileInsights.totalLayers} offloaded ${ - chalk.dim(`(${Math.floor((model.gpuLayers / model.fileInsights.totalLayers) * 100)}%)`) - }` - }] - }); - } - printInfoLine({ - title: "Model", - padTitle: padTitle, - info: [{ - title: "Type", - value: model.typeDescription - }, { - title: "Size", - value: bytes(model.size) - }, { - title: "Train context size", - value: String(model.trainContextSize) - }] - }); - printInfoLine({ - title: "Context", - padTitle: padTitle, - info: [{ - title: "Size", - value: String(context.contextSize) - }, { - show: logBatchSize, - title: "Batch size", - value: bytes(context.batchSize) - }, { - show: meter, - title: "Token meter", - value: "enabled" - }] + printCommonInfoLines({ + context, + minTitleLength: padTitle, + logBatchSize, + tokenMeterEnabled: meter }); printInfoLine({ title: "Complete", diff --git a/src/cli/commands/InfillCommand.ts b/src/cli/commands/InfillCommand.ts index acd2203f..aefbb882 100644 --- a/src/cli/commands/InfillCommand.ts +++ b/src/cli/commands/InfillCommand.ts @@ -11,7 +11,7 @@ import {LlamaCompletion} from "../../evaluator/LlamaCompletion.js"; import withOra from "../../utils/withOra.js"; import {TokenMeter} from "../../evaluator/TokenMeter.js"; import {printInfoLine} from "../utils/printInfoLine.js"; -import {getPrettyBuildGpuName} from "../../bindings/consts.js"; +import {printCommonInfoLines} from "../utils/printCommonInfoLines.js"; type InfillCommand = { model: string, @@ -82,6 +82,7 @@ export const InfillCommand: CommandModule = { alias: "c", type: "number", description: "Context size to use for the model context", + default: -1, defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) @@ -129,6 +130,7 @@ export const InfillCommand: CommandModule = { alias: "gl", type: "number", description: "number of layers to store in VRAM", + default: -1, defaultDescription: "Automatically determined based on the available VRAM", group: "Optional:" }) @@ -222,6 +224,9 @@ async function RunInfill({ lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, debug, meter, printTimings }: InfillCommand) { + if (contextSize === -1) contextSize = undefined; + if (gpuLayers === -1) gpuLayers = undefined; + if (debug) console.info(`${chalk.yellow("Log level:")} debug`); @@ -310,56 +315,11 @@ async function RunInfill({ await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing const padTitle = "Context".length + 1; - if (llama.gpu !== false) { - printInfoLine({ - title: "GPU", - padTitle: padTitle, - info: [{ - title: "Type", - value: getPrettyBuildGpuName(llama.gpu) - }, { - title: "VRAM", - value: bytes(llama.getVramState().total) - }, { - title: "Name", - value: llama.getGpuDeviceNames().join(", ") - }, { - title: "GPU layers", - value: `${model.gpuLayers}/${model.fileInsights.totalLayers} offloaded ${ - chalk.dim(`(${Math.floor((model.gpuLayers / model.fileInsights.totalLayers) * 100)}%)`) - }` - }] - }); - } - printInfoLine({ - title: "Model", - padTitle: padTitle, - info: [{ - title: "Type", - value: model.typeDescription - }, { - title: "Size", - value: bytes(model.size) - }, { - title: "Train context size", - value: String(model.trainContextSize) - }] - }); - printInfoLine({ - title: "Context", - padTitle: padTitle, - info: [{ - title: "Size", - value: String(context.contextSize) - }, { - show: logBatchSize, - title: "Batch size", - value: bytes(context.batchSize) - }, { - show: meter, - title: "Token meter", - value: "enabled" - }] + printCommonInfoLines({ + context, + minTitleLength: padTitle, + logBatchSize, + tokenMeterEnabled: meter }); printInfoLine({ title: "Infill", diff --git a/src/cli/utils/printCommonInfoLines.ts b/src/cli/utils/printCommonInfoLines.ts new file mode 100644 index 00000000..a6e9b54e --- /dev/null +++ b/src/cli/utils/printCommonInfoLines.ts @@ -0,0 +1,85 @@ +import bytes from "bytes"; +import chalk from "chalk"; +import {getPrettyBuildGpuName} from "../../bindings/consts.js"; +import {LlamaContext} from "../../evaluator/LlamaContext/LlamaContext.js"; +import {printInfoLine} from "./printInfoLine.js"; + +export function printCommonInfoLines({ + context, + minTitleLength = 0, + logBatchSize = false, + tokenMeterEnabled = false, + printBos = false, + printEos = false +}: { + context: LlamaContext, + minTitleLength?: number, + logBatchSize?: boolean, + tokenMeterEnabled?: boolean, + printBos?: boolean, + printEos?: boolean +}) { + const llama = context._llama; + const model = context.model; + const padTitle = Math.max(minTitleLength, "Context".length + 1); + + if (llama.gpu !== false) { + printInfoLine({ + title: "GPU", + padTitle: padTitle, + info: [{ + title: "Type", + value: getPrettyBuildGpuName(llama.gpu) + }, { + title: "VRAM", + value: bytes(llama.getVramState().total) + }, { + title: "Name", + value: llama.getGpuDeviceNames().join(", ") + }, { + title: "GPU layers", + value: `${model.gpuLayers}/${model.fileInsights.totalLayers} offloaded ${ + chalk.dim(`(${Math.floor((model.gpuLayers / model.fileInsights.totalLayers) * 100)}%)`) + }` + }] + }); + } + printInfoLine({ + title: "Model", + padTitle: padTitle, + info: [{ + title: "Type", + value: model.typeDescription + }, { + title: "Size", + value: bytes(model.size) + }, { + show: printBos, + title: "BOS", + value: () => String(model.tokens.bosString) + }, { + show: printEos, + title: "EOS", + value: () => String(model.tokens.eosString) + }, { + title: "Train context size", + value: String(model.trainContextSize) + }] + }); + printInfoLine({ + title: "Context", + padTitle: padTitle, + info: [{ + title: "Size", + value: String(context.contextSize) + }, { + show: logBatchSize, + title: "Batch size", + value: bytes(context.batchSize) + }, { + show: tokenMeterEnabled, + title: "Token meter", + value: "enabled" + }] + }); +} diff --git a/src/evaluator/LlamaChat/LlamaChat.ts b/src/evaluator/LlamaChat/LlamaChat.ts index 67f005c3..6229ee59 100644 --- a/src/evaluator/LlamaChat/LlamaChat.ts +++ b/src/evaluator/LlamaChat/LlamaChat.ts @@ -404,7 +404,7 @@ export class LlamaChat { ignoredStartTextTokens = mostExhaustiveTriggeredStop.stopTrigger .map((stopTrigger) => { if (typeof stopTrigger === "string") - return model.tokenize(stopTrigger); + return model.tokenize(stopTrigger, false, "trimLeadingSpace"); else return [stopTrigger]; }) @@ -413,7 +413,7 @@ export class LlamaChat { const newPendingTokens = mostExhaustiveTriggeredStop.remainingGenerations .map((generation) => { if (typeof generation === "string") - return model.tokenize(generation); + return model.tokenize(generation, false, "trimLeadingSpace"); else return generation; }) @@ -640,7 +640,9 @@ export class LlamaChat { ? firstRemainingGenerationAfterStop : model.detokenize(firstRemainingGenerationAfterStop); - functionCallTokens.push(...model.tokenize(this._chatWrapper.settings.functions.call.prefix + remainingTextAfterStop)); + functionCallTokens.push(...model.tokenize( + this._chatWrapper.settings.functions.call.prefix + remainingTextAfterStop, false, "trimLeadingSpace" + )); for (const functionCallToken of functionCallTokens) context._acceptTokenOnGrammarEvaluationState(functionsEvaluationState, functionCallToken); diff --git a/src/evaluator/LlamaModel.ts b/src/evaluator/LlamaModel.ts index ff44ea6f..0437a50f 100644 --- a/src/evaluator/LlamaModel.ts +++ b/src/evaluator/LlamaModel.ts @@ -218,9 +218,9 @@ export class LlamaModel { * For example, `` will be tokenized to the BOS token if `specialTokens` is set to `true`, * otherwise it will be tokenized to tokens that corresponds to the plaintext `` string. */ - public tokenize(text: string, specialTokens?: boolean | "trimLeadingSpace"): Token[]; + public tokenize(text: string, specialTokens?: boolean, options?: "trimLeadingSpace"): Token[]; public tokenize(text: BuiltinSpecialTokenValue, specialTokens: "builtin"): Token[]; - public tokenize(text: string, specialTokens: boolean | "builtin" | "trimLeadingSpace" = false): Token[] { + public tokenize(text: string, specialTokens: boolean | "builtin" = false, options?: "trimLeadingSpace"): Token[] { this._ensureNotDisposed(); if (text === "") @@ -239,24 +239,35 @@ export class LlamaModel { throw new Error(`Unknown builtin special token: ${builtinToken}`); } - if (specialTokens === "trimLeadingSpace") { - specialTokens = true; - - const [workaroundToken, workaroundTokenString] = (this.tokens.bos != null && this.tokens.bosString != null) - ? [this.tokens.bos, this.tokens.bosString] - : (this.tokens.eos != null && this.tokens.eosString != null) - ? [this.tokens.eos, this.tokens.eosString] - : (this.tokens.nl != null && this.tokens.nlString != null) - ? [this.tokens.nl, this.tokens.nlString] - : [null, null]; + if (options === "trimLeadingSpace") { + if (specialTokens) { + const [workaroundToken, workaroundTokenString] = (this.tokens.bos != null && this.tokens.bosString != null) + ? [this.tokens.bos, this.tokens.bosString] + : (this.tokens.eos != null && this.tokens.eosString != null) + ? [this.tokens.eos, this.tokens.eosString] + : (this.tokens.nl != null && this.tokens.nlString != null) + ? [this.tokens.nl, this.tokens.nlString] + : [null, null]; + + if (workaroundToken != null && workaroundTokenString != null) { + const tokens = Array.from(this._model.tokenize(workaroundTokenString + text, true)) as Token[]; + const workaroundTokenIndex = tokens.indexOf(workaroundToken); + + // only use the tokenized output if it can be corrected, otherwise fallback to the default tokenization + if (workaroundTokenIndex >= 0 && workaroundTokenIndex <= 1) { + tokens.splice(0, workaroundTokenIndex + 1); + return tokens; + } + } + } else { + const workaroundTokens = Array.from(this._model.tokenize("\n", false)) as Token[]; + const workaroundTokensString = "\n"; - if (workaroundToken != null && workaroundTokenString != null) { - const tokens = Array.from(this._model.tokenize(workaroundTokenString + text, true)) as Token[]; - const workaroundTokenIndex = tokens.indexOf(workaroundToken); + const tokens = Array.from(this._model.tokenize(workaroundTokensString + text, false)) as Token[]; // only use the tokenized output if it can be corrected, otherwise fallback to the default tokenization - if (workaroundTokenIndex >= 0 && workaroundTokenIndex <= 1) { - tokens.splice(0, workaroundTokenIndex + 1); + if (workaroundTokens.length > 0 && workaroundTokens.every((token, index) => tokens[index] === token)) { + tokens.splice(0, workaroundTokens.length); return tokens; } } diff --git a/src/types.ts b/src/types.ts index 13fd619c..ff8dc6f2 100644 --- a/src/types.ts +++ b/src/types.ts @@ -6,7 +6,7 @@ export type Token = number & { }; export type Tokenizer = { - tokenize(text: string, specialTokens?: boolean | "trimLeadingSpace"): Token[], + tokenize(text: string, specialTokens?: boolean, options?: "trimLeadingSpace"): Token[], tokenize(text: BuiltinSpecialTokenValue, specialTokens: "builtin"): Token[] }["tokenize"]; diff --git a/src/utils/LlamaText.ts b/src/utils/LlamaText.ts index ccaffd39..c53ef21a 100644 --- a/src/utils/LlamaText.ts +++ b/src/utils/LlamaText.ts @@ -84,7 +84,7 @@ export class SpecialTokensText { } public tokenize(tokenizer: Tokenizer, trimLeadingSpace: boolean = false): Token[] { - return tokenizer(this.value, trimLeadingSpace ? "trimLeadingSpace" : true); + return tokenizer(this.value, true, trimLeadingSpace ? "trimLeadingSpace" : undefined); } public toJSON(): LlamaTextSpecialTokensTextJSON { diff --git a/src/utils/StopGenerationDetector.ts b/src/utils/StopGenerationDetector.ts index c136f31b..2c0b641f 100644 --- a/src/utils/StopGenerationDetector.ts +++ b/src/utils/StopGenerationDetector.ts @@ -248,7 +248,7 @@ export class StopGenerationDetector { else if (value instanceof SpecialToken) return value.tokenize(tokenizer); else if (value instanceof SpecialTokensText) - return value.tokenize(tokenizer); + return value.tokenize(tokenizer, true); return value satisfies never; }) From 103795960fb606533493e71d2b93f903fbd34a13 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 2 Apr 2024 23:37:56 +0300 Subject: [PATCH 38/52] fix: bugs --- src/chatWrappers/LlamaChatWrapper.ts | 9 ++++++--- src/cli/commands/CompleteCommand.ts | 1 - src/cli/commands/InfillCommand.ts | 1 - src/utils/LlamaText.ts | 7 ++++--- src/utils/parseModelFileName.ts | 6 +++++- 5 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/chatWrappers/LlamaChatWrapper.ts b/src/chatWrappers/LlamaChatWrapper.ts index 71b17b5c..f79c092e 100644 --- a/src/chatWrappers/LlamaChatWrapper.ts +++ b/src/chatWrappers/LlamaChatWrapper.ts @@ -9,7 +9,7 @@ export class LlamaChatWrapper extends ChatWrapper { /** @internal */ private readonly _addSpaceBeforeEos: boolean; public constructor({ - addSpaceBeforeEos = true + addSpaceBeforeEos = false }: { /** * Default to `true` @@ -26,7 +26,8 @@ export class LlamaChatWrapper extends ChatWrapper { documentFunctionParams?: boolean } = {}): { contextText: LlamaText, - stopGenerationTriggers: LlamaText[] + stopGenerationTriggers: LlamaText[], + ignoreStartText?: LlamaText[] } { const historyWithFunctions = this.addAvailableFunctionsSystemMessageToHistory(history, availableFunctions, { documentParams: documentFunctionParams @@ -120,8 +121,10 @@ export class LlamaChatWrapper extends ChatWrapper { /** @internal */ public static override _getOptionConfigurationsToTestIfCanSupersedeJinjaTemplate() { - return [{}, { + return [{ addSpaceBeforeEos: false + }, { + addSpaceBeforeEos: true }] satisfies Partial[0]>[]; } } diff --git a/src/cli/commands/CompleteCommand.ts b/src/cli/commands/CompleteCommand.ts index 0d1ff92a..da126f6d 100644 --- a/src/cli/commands/CompleteCommand.ts +++ b/src/cli/commands/CompleteCommand.ts @@ -4,7 +4,6 @@ import path from "path"; import {CommandModule} from "yargs"; import chalk from "chalk"; import fs from "fs-extra"; -import bytes from "bytes"; import {getLlama} from "../../bindings/getLlama.js"; import {LlamaLogLevel, LlamaLogLevelGreaterThan} from "../../bindings/types.js"; import {LlamaCompletion} from "../../evaluator/LlamaCompletion.js"; diff --git a/src/cli/commands/InfillCommand.ts b/src/cli/commands/InfillCommand.ts index aefbb882..f7bc18f1 100644 --- a/src/cli/commands/InfillCommand.ts +++ b/src/cli/commands/InfillCommand.ts @@ -4,7 +4,6 @@ import path from "path"; import {CommandModule} from "yargs"; import chalk from "chalk"; import fs from "fs-extra"; -import bytes from "bytes"; import {getLlama} from "../../bindings/getLlama.js"; import {LlamaLogLevel, LlamaLogLevelGreaterThan} from "../../bindings/types.js"; import {LlamaCompletion} from "../../evaluator/LlamaCompletion.js"; diff --git a/src/utils/LlamaText.ts b/src/utils/LlamaText.ts index c53ef21a..2d0b1320 100644 --- a/src/utils/LlamaText.ts +++ b/src/utils/LlamaText.ts @@ -204,19 +204,20 @@ const LlamaTextPrototypeFunctions: Partial = { tokenize(this: LlamaText, tokenizer): Token[] { let textToTokenize = ""; const res: Token[] = []; + const hasContent = () => (res.length > 0 || textToTokenize.length > 0); for (const value of this.values) { if (value instanceof SpecialToken) { - res.push(...tokenizer(textToTokenize, false), ...value.tokenize(tokenizer)); + res.push(...tokenizer(textToTokenize, false, hasContent() ? "trimLeadingSpace" : undefined), ...value.tokenize(tokenizer)); textToTokenize = ""; } else if (value instanceof SpecialTokensText) { - res.push(...tokenizer(textToTokenize, false), ...value.tokenize(tokenizer, res.length > 0 || textToTokenize.length > 0)); + res.push(...tokenizer(textToTokenize, false, hasContent() ? "trimLeadingSpace" : undefined), ...value.tokenize(tokenizer, hasContent())); textToTokenize = ""; } else textToTokenize += value; } - res.push(...tokenizer(textToTokenize, false)); + res.push(...tokenizer(textToTokenize, false, hasContent() ? "trimLeadingSpace" : undefined)); return res; }, diff --git a/src/utils/parseModelFileName.ts b/src/utils/parseModelFileName.ts index 37634cf1..e2a3ceb2 100644 --- a/src/utils/parseModelFileName.ts +++ b/src/utils/parseModelFileName.ts @@ -20,6 +20,7 @@ export function parseModelFileName(filename: string) { const {previousParts, parameters, nextParts} = splitByModelParameters(parts); const name = previousParts.shift(); + const otherInfo: string[] = []; for (let i = 0; i < nextParts.length; i++) { const part = nextParts[i]; @@ -31,6 +32,8 @@ export function parseModelFileName(filename: string) { version = part.toLowerCase(); nextParts.splice(i, 1); i--; + } else { + otherInfo.push(part); } } @@ -41,7 +44,8 @@ export function parseModelFileName(filename: string) { fileType, version, contextSize, - parameters + parameters, + otherInfo }; } From e4ccbfffe15fd0fa53805308ceeac876abaa5802 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 2 Apr 2024 23:39:07 +0300 Subject: [PATCH 39/52] feat: add `noJinja` and `noTrimWhitespace` flags to the `chat` command --- src/chatWrappers/utils/resolveChatWrapper.ts | 18 +++++-- src/cli/commands/ChatCommand.ts | 53 +++++++++++++++----- 2 files changed, 54 insertions(+), 17 deletions(-) diff --git a/src/chatWrappers/utils/resolveChatWrapper.ts b/src/chatWrappers/utils/resolveChatWrapper.ts index 2fc76ebc..f82c0d8b 100644 --- a/src/chatWrappers/utils/resolveChatWrapper.ts +++ b/src/chatWrappers/utils/resolveChatWrapper.ts @@ -66,7 +66,12 @@ export type ResolveChatWrapperOptions = { [wrapper in keyof typeof chatWrappers]?: ConstructorParameters<(typeof chatWrappers)[wrapper]>[0] }, warningLogs?: boolean, - fallbackToOtherWrappersOnJinjaError?: boolean + fallbackToOtherWrappersOnJinjaError?: boolean, + + /** + * Don't resolve to a Jinja chat wrapper unless `type` is set to a Jinja chat wrapper type. + */ + noJinja?: boolean }; /** @@ -90,7 +95,8 @@ export function resolveChatWrapper({ tokenizer, customWrapperSettings, warningLogs = true, - fallbackToOtherWrappersOnJinjaError = true + fallbackToOtherWrappersOnJinjaError = true, + noJinja = false }: ResolveChatWrapperOptions) { function createSpecializedChatWrapper( specializedChatWrapper: T, @@ -150,7 +156,7 @@ export function resolveChatWrapper({ const modelJinjaTemplate = customWrapperSettings?.jinjaTemplate?.template ?? fileInfo?.metadata?.tokenizer?.chat_template; - if (modelJinjaTemplate != null && modelJinjaTemplate.trim() !== "") { + if (!noJinja && modelJinjaTemplate != null && modelJinjaTemplate.trim() !== "") { const jinjaTemplateChatWrapperOptions: JinjaTemplateChatWrapperOptions = { ...(customWrapperSettings?.jinjaTemplate ?? {}), template: modelJinjaTemplate @@ -200,12 +206,14 @@ export function resolveChatWrapper({ } if (filename != null) { - const {name, subType, fileType} = parseModelFileName(filename); + const {name, subType, fileType, otherInfo} = parseModelFileName(filename); if (fileType?.toLowerCase() === "gguf") { const lowercaseName = name?.toLowerCase(); const lowercaseSubType = subType?.toLowerCase(); - const splitLowercaseSubType = lowercaseSubType?.split("-") ?? []; + const splitLowercaseSubType = (lowercaseSubType?.split("-") ?? []).concat( + otherInfo.map(info => info.toLowerCase()) + ); const firstSplitLowercaseSubType = splitLowercaseSubType[0]; if (lowercaseName === "llama") { diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts index 08a9b241..ec05aa60 100644 --- a/src/cli/commands/ChatCommand.ts +++ b/src/cli/commands/ChatCommand.ts @@ -4,7 +4,6 @@ import path from "path"; import {CommandModule} from "yargs"; import chalk from "chalk"; import fs from "fs-extra"; -import bytes from "bytes"; import {chatCommandHistoryFilePath, defaultChatSystemPrompt} from "../../config.js"; import {getIsInDocumentationMode} from "../../state.js"; import {ReplHistory} from "../../utils/ReplHistory.js"; @@ -31,8 +30,10 @@ type ChatCommand = { prompt?: string, promptFile?: string, wrapper: SpecializedChatWrapperTypeName | "auto", + noJinja?: boolean, contextSize?: number, batchSize?: number, + noTrimWhitespace: boolean, grammar: "text" | Parameters[1], jsonSchemaGrammarFile?: string, threads: number, @@ -108,6 +109,12 @@ export const ChatCommand: CommandModule = { description: "Chat wrapper to use. Use `auto` to automatically select a wrapper based on the model's BOS token", group: "Optional:" }) + .option("noJinja", { + type: "boolean", + default: false, + description: "Don't use a Jinja wrapper, even if it's the best option for the model", + group: "Optional:" + }) .option("contextSize", { alias: "c", type: "number", @@ -122,6 +129,13 @@ export const ChatCommand: CommandModule = { description: "Batch size to use for the model context. The default value is the context size", group: "Optional:" }) + .option("noTrimWhitespace", { + type: "boolean", + alias: ["noTrim"], + default: false, + description: "Don't trim whitespaces from the model response", + group: "Optional:" + }) .option("grammar", { alias: "g", type: "string", @@ -255,17 +269,17 @@ export const ChatCommand: CommandModule = { }, async handler({ model, systemInfo, systemPrompt, systemPromptFile, prompt, - promptFile, wrapper, contextSize, batchSize, - grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, + promptFile, wrapper, noJinja, contextSize, batchSize, + noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings }) { try { await RunChat({ - model, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, contextSize, batchSize, - grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, - repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, + model, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja, contextSize, batchSize, + noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers, + lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings }); } catch (err) { @@ -278,14 +292,16 @@ export const ChatCommand: CommandModule = { async function RunChat({ - model: modelArg, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, contextSize, batchSize, - grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath, threads, temperature, minP, topK, topP, gpuLayers, - lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, + model: modelArg, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja, contextSize, batchSize, + noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath, threads, temperature, minP, topK, topP, + gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings }: ChatCommand) { if (contextSize === -1) contextSize = undefined; if (gpuLayers === -1) gpuLayers = undefined; + const trimWhitespace = !noTrimWhitespace; + if (debug) console.info(`${chalk.yellow("Log level:")} debug`); @@ -372,7 +388,8 @@ async function RunChat({ bosString: model.tokens.bosString, filename: model.filename, fileInfo: model.fileInfo, - tokenizer: model.tokenize + tokenizer: model.tokenize, + noJinja }) ?? new GeneralChatWrapper(); const contextSequence = context.getSequence(); const session = new LlamaChatSession({ @@ -457,6 +474,7 @@ async function RunChat({ // eslint-disable-next-line no-constant-condition while (true) { + let hadNoWhitespaceTextInThisIteration = false; const input = initialPrompt != null ? initialPrompt : await getPrompt(); @@ -494,11 +512,22 @@ async function RunChat({ ? undefined : maxTokens, onToken(chunk) { - process.stdout.write(model.detokenize(chunk)); + const text = model.detokenize(chunk); + + if (trimWhitespace && !hadNoWhitespaceTextInThisIteration) { + const trimmedText = text.trimStart(); + + if (trimmedText.length > 0) { + process.stdout.write(trimmedText); + hadNoWhitespaceTextInThisIteration = true; + } + } else + process.stdout.write(text); }, functions: (grammar == null && environmentFunctions) ? defaultEnvironmentFunctions - : undefined + : undefined, + trimWhitespaceSuffix: trimWhitespace }); process.stdout.write(endColor); console.log(); From a61ba629ddd65ccfaafb1fd7eed35ebed238f3ea Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 2 Apr 2024 23:39:50 +0300 Subject: [PATCH 40/52] test: update `parseModelFileName.test.ts` --- test/standalone/parseModelFileName.test.ts | 36 +++++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/test/standalone/parseModelFileName.test.ts b/test/standalone/parseModelFileName.test.ts index 246835fd..4e4b840b 100644 --- a/test/standalone/parseModelFileName.test.ts +++ b/test/standalone/parseModelFileName.test.ts @@ -10,7 +10,8 @@ describe("parseModelFileName", () => { quantization: "Q4_K_M", fileType: "gguf", version: undefined, - parameters: "13B" + parameters: "13B", + otherInfo: [] }); }); @@ -22,7 +23,8 @@ describe("parseModelFileName", () => { quantization: "Q4_K_M", fileType: "gguf", version: "v2", - parameters: "13B" + parameters: "13B", + otherInfo: [] }); }); @@ -34,7 +36,8 @@ describe("parseModelFileName", () => { quantization: "Q4_K_M", fileType: "gguf", version: "v2", - parameters: "34B" + parameters: "34B", + otherInfo: [] }); }); @@ -46,7 +49,8 @@ describe("parseModelFileName", () => { subType: "llama-2", quantization: "Q5_K_S", fileType: "gguf", - parameters: "13B" + parameters: "13B", + otherInfo: [] }); }); @@ -56,7 +60,8 @@ describe("parseModelFileName", () => { name: "functionary", subType: "small-v2.2", quantization: "q4_0", - fileType: "gguf" + fileType: "gguf", + otherInfo: [] }); }); @@ -67,7 +72,8 @@ describe("parseModelFileName", () => { subType: "alpaca", quantization: "Q5_K_M", fileType: "gguf", - parameters: "13B" + parameters: "13B", + otherInfo: [] }); }); @@ -78,7 +84,8 @@ describe("parseModelFileName", () => { subType: "2.1-mistral", quantization: "Q4_K_M", fileType: "gguf", - parameters: "7B" + parameters: "7B", + otherInfo: [] }); }); @@ -89,7 +96,20 @@ describe("parseModelFileName", () => { subType: "", quantization: "Q5_K_M", fileType: "gguf", - parameters: "7B" + parameters: "7B", + otherInfo: ["it"] + }); + }); + + test("llama-2-7b-chat.Q4_0.gguf", () => { + expect(parseModelFileName("llama-2-7b-chat.Q4_0.gguf")) + .toEqual({ + name: "llama", + subType: "2", + quantization: "Q4_0", + fileType: "gguf", + parameters: "7B", + otherInfo: ["chat"] }); }); }); From eecf2c3b92567cb0a3ac53de3bd8841a6ef0ff18 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Tue, 2 Apr 2024 23:40:01 +0300 Subject: [PATCH 41/52] test: move files --- .../chatWrappers/{generic => }/LlamaChatPromptWrapper.test.ts | 4 ++-- .../chatWrappers/{ => generic}/TemplateChatWrapper.test.ts | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) rename test/standalone/chatWrappers/{generic => }/LlamaChatPromptWrapper.test.ts (97%) rename test/standalone/chatWrappers/{ => generic}/TemplateChatWrapper.test.ts (99%) diff --git a/test/standalone/chatWrappers/generic/LlamaChatPromptWrapper.test.ts b/test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts similarity index 97% rename from test/standalone/chatWrappers/generic/LlamaChatPromptWrapper.test.ts rename to test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts index 3e10a066..7bb478ca 100644 --- a/test/standalone/chatWrappers/generic/LlamaChatPromptWrapper.test.ts +++ b/test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts @@ -1,6 +1,6 @@ import {describe, expect, test} from "vitest"; -import {ChatHistoryItem, LlamaChatWrapper} from "../../../../src/index.js"; -import {defaultChatSystemPrompt} from "../../../../src/config.js"; +import {ChatHistoryItem, LlamaChatWrapper} from "../../../src/index.js"; +import {defaultChatSystemPrompt} from "../../../src/config.js"; describe("LlamaChatWrapper", () => { diff --git a/test/standalone/chatWrappers/TemplateChatWrapper.test.ts b/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts similarity index 99% rename from test/standalone/chatWrappers/TemplateChatWrapper.test.ts rename to test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts index 58e5215f..71306ac0 100644 --- a/test/standalone/chatWrappers/TemplateChatWrapper.test.ts +++ b/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts @@ -1,6 +1,6 @@ import {describe, expect, test} from "vitest"; -import {ChatHistoryItem, TemplateChatWrapper} from "../../../src/index.js"; -import {defaultChatSystemPrompt} from "../../../src/config.js"; +import {ChatHistoryItem, TemplateChatWrapper} from "../../../../src/index.js"; +import {defaultChatSystemPrompt} from "../../../../src/config.js"; describe("TemplateChatWrapper", () => { From c95f2fe08bf217e06ad5ada137d50b68211b888a Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 3 Apr 2024 01:04:14 +0300 Subject: [PATCH 42/52] fix: bugs --- src/chatWrappers/utils/resolveChatWrapper.ts | 2 + src/cli/commands/ChatCommand.ts | 26 ++-- src/evaluator/LlamaCompletion.ts | 12 +- src/utils/LlamaText.ts | 11 +- src/utils/getQueuedTokensBeforeStopTrigger.ts | 6 +- src/utils/tokenizeInput.ts | 6 +- .../functionary/chatSession.test.ts | 36 +++-- .../functionary/grammar.test.ts | 2 +- .../modelDependent/functionary/sanity.test.ts | 2 +- .../FalconChatPromptWrapper.test.ts | 74 +++------- .../chatWrappers/GemmaChatWrapper.test.ts | 56 +++----- .../GeneralChatPromptWrapper.test.ts | 134 +++++------------- .../LlamaChatPromptWrapper.test.ts | 78 +++------- .../generic/JinjaTemplateChatWrapper.test.ts | 122 +++++++--------- .../generic/TemplateChatWrapper.test.ts | 98 ++++++------- test/standalone/parseModelFileName.test.ts | 24 ++++ 16 files changed, 276 insertions(+), 413 deletions(-) diff --git a/src/chatWrappers/utils/resolveChatWrapper.ts b/src/chatWrappers/utils/resolveChatWrapper.ts index f82c0d8b..ce80afd0 100644 --- a/src/chatWrappers/utils/resolveChatWrapper.ts +++ b/src/chatWrappers/utils/resolveChatWrapper.ts @@ -239,6 +239,8 @@ export function resolveChatWrapper({ return createSpecializedChatWrapper(ChatMLChatWrapper); else if (lowercaseName === "gemma") return createSpecializedChatWrapper(GemmaChatWrapper); + else if (splitLowercaseSubType.includes("chatml")) + return createSpecializedChatWrapper(ChatMLChatWrapper); } } diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts index ec05aa60..5225f72f 100644 --- a/src/cli/commands/ChatCommand.ts +++ b/src/cli/commands/ChatCommand.ts @@ -475,6 +475,7 @@ async function RunChat({ // eslint-disable-next-line no-constant-condition while (true) { let hadNoWhitespaceTextInThisIteration = false; + let nextPrintLeftovers = ""; const input = initialPrompt != null ? initialPrompt : await getPrompt(); @@ -512,17 +513,26 @@ async function RunChat({ ? undefined : maxTokens, onToken(chunk) { - const text = model.detokenize(chunk); + let text = nextPrintLeftovers + model.detokenize(chunk); + nextPrintLeftovers = ""; - if (trimWhitespace && !hadNoWhitespaceTextInThisIteration) { - const trimmedText = text.trimStart(); + if (trimWhitespace) { + if (!hadNoWhitespaceTextInThisIteration) { + text = text.trimStart(); - if (trimmedText.length > 0) { - process.stdout.write(trimmedText); - hadNoWhitespaceTextInThisIteration = true; + if (text.length > 0) + hadNoWhitespaceTextInThisIteration = true; } - } else - process.stdout.write(text); + + const textWithTrimmedEnd = text.trimEnd(); + + if (textWithTrimmedEnd.length < text.length) { + nextPrintLeftovers = text.slice(textWithTrimmedEnd.length); + text = textWithTrimmedEnd; + } + } + + process.stdout.write(text); }, functions: (grammar == null && environmentFunctions) ? defaultEnvironmentFunctions diff --git a/src/evaluator/LlamaCompletion.ts b/src/evaluator/LlamaCompletion.ts index a3da92bc..86845f9f 100644 --- a/src/evaluator/LlamaCompletion.ts +++ b/src/evaluator/LlamaCompletion.ts @@ -243,7 +243,13 @@ export class LlamaCompletion { if (this._sequence == null || this.disposed) throw new DisposedError(); - const resolvedInput = tokenizeInput(input, this._sequence.model.tokenize); + const resolvedInput = tokenizeInput( + input, + this._sequence.model.tokenize, + (shouldPrependBosToken && bosToken != null) + ? "trimLeadingSpace" + : undefined + ); const resolvedContextShiftSize = await resolveContextShiftSize(contextShiftSize, this._sequence); ensureNotAborted(); @@ -420,8 +426,8 @@ export class LlamaCompletion { if (this._sequence == null || this.disposed) throw new DisposedError(); - const resolvedPrefixInputTokens = tokenizeInput(prefixInput, this._sequence.model.tokenize); - const resolvedSuffixInputTokens = tokenizeInput(suffixInput, this._sequence.model.tokenize); + const resolvedPrefixInputTokens = tokenizeInput(prefixInput, this._sequence.model.tokenize, "trimLeadingSpace"); + const resolvedSuffixInputTokens = tokenizeInput(suffixInput, this._sequence.model.tokenize, "trimLeadingSpace"); const resolvedContextShiftSize = await resolveContextShiftSize(contextShiftSize, this._sequence); ensureNotAborted(); diff --git a/src/utils/LlamaText.ts b/src/utils/LlamaText.ts index 2d0b1320..e2d999c9 100644 --- a/src/utils/LlamaText.ts +++ b/src/utils/LlamaText.ts @@ -21,7 +21,7 @@ export type LlamaText = { joinValues(separator: LlamaText | V): LlamaText, toString(): string, toJSON(): LlamaTextJSON, - tokenize(tokenizer: Tokenizer): Token[], + tokenize(tokenizer: Tokenizer, options?: "trimLeadingSpace"): Token[], compare(other: LlamaText): boolean, trimStart(): LlamaText, trimEnd(): LlamaText, @@ -201,23 +201,24 @@ const LlamaTextPrototypeFunctions: Partial = { }) .join(""); }, - tokenize(this: LlamaText, tokenizer): Token[] { + tokenize(this: LlamaText, tokenizer, options?: "trimLeadingSpace"): Token[] { let textToTokenize = ""; const res: Token[] = []; const hasContent = () => (res.length > 0 || textToTokenize.length > 0); + const resolveTokenizerOptions = () => (hasContent() ? "trimLeadingSpace" : options); for (const value of this.values) { if (value instanceof SpecialToken) { - res.push(...tokenizer(textToTokenize, false, hasContent() ? "trimLeadingSpace" : undefined), ...value.tokenize(tokenizer)); + res.push(...tokenizer(textToTokenize, false, resolveTokenizerOptions()), ...value.tokenize(tokenizer)); textToTokenize = ""; } else if (value instanceof SpecialTokensText) { - res.push(...tokenizer(textToTokenize, false, hasContent() ? "trimLeadingSpace" : undefined), ...value.tokenize(tokenizer, hasContent())); + res.push(...tokenizer(textToTokenize, false, resolveTokenizerOptions()), ...value.tokenize(tokenizer, hasContent() || options === "trimLeadingSpace")); textToTokenize = ""; } else textToTokenize += value; } - res.push(...tokenizer(textToTokenize, false, hasContent() ? "trimLeadingSpace" : undefined)); + res.push(...tokenizer(textToTokenize, false, resolveTokenizerOptions())); return res; }, diff --git a/src/utils/getQueuedTokensBeforeStopTrigger.ts b/src/utils/getQueuedTokensBeforeStopTrigger.ts index 64abb2c4..63fecede 100644 --- a/src/utils/getQueuedTokensBeforeStopTrigger.ts +++ b/src/utils/getQueuedTokensBeforeStopTrigger.ts @@ -14,7 +14,7 @@ export function getQueuedTokensBeforeStopTrigger( else if (partiallyFreeTokens.tokens.length !== 0 && partiallyFreeTokens.text.length === 0) return partiallyFreeTokens.tokens; else if (partiallyFreeTokens.tokens.length === 0 && partiallyFreeTokens.text.length !== 0) - return tokenizer(partiallyFreeTokens.text); + return tokenizer(partiallyFreeTokens.text, false, "trimLeadingSpace"); const triggerThatStartsWithStringIndex = triggeredStops.findIndex( (trigger) => trigger.stopTrigger.length > 0 && typeof trigger.stopTrigger[0] === "string" @@ -26,9 +26,9 @@ export function getQueuedTokensBeforeStopTrigger( if (triggerThatStartsWithTokenIndex > 0 && triggerThatStartsWithStringIndex < 0) return partiallyFreeTokens.tokens; else if (triggerThatStartsWithStringIndex > 0 && triggerThatStartsWithTokenIndex < 0) - return tokenizer(partiallyFreeTokens.text); + return tokenizer(partiallyFreeTokens.text, false, "trimLeadingSpace"); - const stringTokens = tokenizer(partiallyFreeTokens.text); + const stringTokens = tokenizer(partiallyFreeTokens.text, false, "trimLeadingSpace"); if (stringTokens.length === partiallyFreeTokens.tokens.length && stringTokens.every((value, index) => value === partiallyFreeTokens.tokens[index]) ) diff --git a/src/utils/tokenizeInput.ts b/src/utils/tokenizeInput.ts index 4d6408fb..8e16dc34 100644 --- a/src/utils/tokenizeInput.ts +++ b/src/utils/tokenizeInput.ts @@ -1,11 +1,11 @@ import {Token, Tokenizer} from "../types.js"; import {isLlamaText, LlamaText} from "./LlamaText.js"; -export function tokenizeInput(input: Token[] | string | LlamaText, tokenizer: Tokenizer) { +export function tokenizeInput(input: Token[] | string | LlamaText, tokenizer: Tokenizer, options?: "trimLeadingSpace") { if (typeof input === "string") - return tokenizer(input, false); + return tokenizer(input, false, options); else if (isLlamaText(input)) - return input.tokenize(tokenizer); + return input.tokenize(tokenizer, options); return input; } diff --git a/test/modelDependent/functionary/chatSession.test.ts b/test/modelDependent/functionary/chatSession.test.ts index 228ee1ba..7aa033ee 100644 --- a/test/modelDependent/functionary/chatSession.test.ts +++ b/test/modelDependent/functionary/chatSession.test.ts @@ -1,5 +1,5 @@ import {describe, expect, test} from "vitest"; -import {LlamaChatSession} from "../../../src/index.js"; +import {FunctionaryChatWrapper, LlamaChatSession} from "../../../src/index.js"; import {getModelFile} from "../../utils/modelFiles.js"; import {getTestLlama} from "../../utils/getTestLlama.js"; @@ -19,9 +19,11 @@ describe("functionary", () => { contextSequence: context.getSequence() }); + expect(chatSession.chatWrapper).to.be.an.instanceof(FunctionaryChatWrapper); + const res = await chatSession.prompt("How much is 6+6"); - expect(res).to.eql("6+6 equals 12."); + expect(res).to.eql("The sum of 6 and 6 is 12."); const chatHistory = chatSession.getChatHistory(); @@ -33,7 +35,7 @@ describe("functionary", () => { const res2 = await chatSession2.prompt("Repeat your answer"); - expect(res2).to.eql("6+6 equals 12."); + expect(res2).to.eql("The sum of 6 and 6 is 12."); }); test("disposing a context sequences removes the current state", {timeout: 1000 * 60 * 60 * 2}, async () => { @@ -52,14 +54,16 @@ describe("functionary", () => { autoDisposeSequence: false }); + expect(chatSession.chatWrapper).to.be.an.instanceof(FunctionaryChatWrapper); + const res = await chatSession.prompt("How much is 6+6"); - expect(res).to.eql("6+6 equals 12."); + expect(res).to.eql("The sum of 6 and 6 is 12."); const tokenMeterState = contextSequence.tokenMeter.getState(); expect(tokenMeterState).to.toMatchInlineSnapshot(` { - "usedInputTokens": 140, - "usedOutputTokens": 17, + "usedInputTokens": 96, + "usedOutputTokens": 14, "usedRestoreStateTokens": 0, } `); @@ -77,13 +81,13 @@ describe("functionary", () => { const tokenMeterState2 = contextSequence2.tokenMeter.getState(); expect(tokenMeterState2).to.toMatchInlineSnapshot(` { - "usedInputTokens": 142, - "usedOutputTokens": 19, + "usedInputTokens": 98, + "usedOutputTokens": 15, "usedRestoreStateTokens": 0, } `); expect(tokenMeterState2.usedInputTokens).to.be.greaterThanOrEqual(tokenMeterState.usedInputTokens); - expect(res2).to.eql("6+6+6 equals 18."); + expect(res2).to.eql("The sum of 6+6+6 is 18."); }); test("reusing a context sequences utilizes existing state", {timeout: 1000 * 60 * 60 * 2}, async () => { @@ -102,14 +106,16 @@ describe("functionary", () => { autoDisposeSequence: false }); + expect(chatSession.chatWrapper).to.be.an.instanceof(FunctionaryChatWrapper); + const res = await chatSession.prompt("How much is 6+6"); - expect(res).to.eql("6+6 equals 12."); + expect(res).to.eql("The sum of 6 and 6 is 12."); const tokenMeterState = contextSequence.tokenMeter.getState(); expect(tokenMeterState).to.toMatchInlineSnapshot(` { - "usedInputTokens": 140, - "usedOutputTokens": 17, + "usedInputTokens": 96, + "usedOutputTokens": 14, "usedRestoreStateTokens": 0, } `); @@ -124,13 +130,13 @@ describe("functionary", () => { const tokenMeterStateDiff = contextSequence.tokenMeter.diff(tokenMeterState); expect(tokenMeterStateDiff).to.toMatchInlineSnapshot(` { - "usedInputTokens": 25, - "usedOutputTokens": 19, + "usedInputTokens": 10, + "usedOutputTokens": 15, "usedRestoreStateTokens": 0, } `); expect(tokenMeterStateDiff.usedInputTokens).to.be.lessThan(tokenMeterState.usedInputTokens); - expect(res2).to.eql("6+6+6 equals 18."); + expect(res2).to.eql("The sum of 6+6+6 is 18."); }); }); }); diff --git a/test/modelDependent/functionary/grammar.test.ts b/test/modelDependent/functionary/grammar.test.ts index fa4f0655..6f25df6a 100644 --- a/test/modelDependent/functionary/grammar.test.ts +++ b/test/modelDependent/functionary/grammar.test.ts @@ -34,7 +34,7 @@ describe("functionary", () => { } } as const); - const res = await chatSession.prompt("How's your day going so far?", { + const res = await chatSession.prompt("How's your great day going so far?", { grammar }); const parsedRes = grammar.parse(res); diff --git a/test/modelDependent/functionary/sanity.test.ts b/test/modelDependent/functionary/sanity.test.ts index 6a8e5ca8..9440b5f9 100644 --- a/test/modelDependent/functionary/sanity.test.ts +++ b/test/modelDependent/functionary/sanity.test.ts @@ -21,7 +21,7 @@ describe("functionary", () => { const res = await chatSession.prompt("How much is 6+6"); - expect(res).to.eql("6+6 equals 12."); + expect(res).to.eql("The sum of 6 and 6 is 12."); }); test("text is tokenized with special tokens when appropriate", {timeout: 1000 * 60 * 60 * 2}, async () => { diff --git a/test/standalone/chatWrappers/FalconChatPromptWrapper.test.ts b/test/standalone/chatWrappers/FalconChatPromptWrapper.test.ts index 3aaee42d..d22d74e7 100644 --- a/test/standalone/chatWrappers/FalconChatPromptWrapper.test.ts +++ b/test/standalone/chatWrappers/FalconChatPromptWrapper.test.ts @@ -38,22 +38,15 @@ describe("FalconChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. - ", - "User: ", - "Hi there!", - " + User: Hi there! - ", - "Assistant: ", - "Hello!", + Assistant: Hello!", ] `); @@ -63,32 +56,19 @@ describe("FalconChatWrapper", () => { expect(contextText2.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " - - ", - "User: ", - "Hi there!", - " - - ", - "Assistant: ", - "Hello!", - " - - ", - "User: ", - "How are you?", - " - - ", - "Assistant: ", - "I'm good, how are you?", + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. + + User: Hi there! + + Assistant: Hello! + + User: How are you? + + Assistant: I'm good, how are you?", ] `); @@ -105,48 +85,32 @@ describe("FalconChatWrapper", () => { expect(contextText3.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. - ", - "User: ", - "Hi there!", - " + User: Hi there! - ", - "Assistant: ", - "Hello!", + Assistant: Hello!", ] `); expect(contextText3WithOpenModelResponse.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. - ", - "User: ", - "Hi there!", - " + User: Hi there! - ", - "Assistant: ", - "Hello!", - " + Assistant: Hello! - ", - "Assistant: ", + Assistant: ", ] `); }); diff --git a/test/standalone/chatWrappers/GemmaChatWrapper.test.ts b/test/standalone/chatWrappers/GemmaChatWrapper.test.ts index d69ba4d2..badb68f4 100644 --- a/test/standalone/chatWrappers/GemmaChatWrapper.test.ts +++ b/test/standalone/chatWrappers/GemmaChatWrapper.test.ts @@ -38,7 +38,7 @@ describe("GemmaChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "user ", }, @@ -49,13 +49,9 @@ describe("GemmaChatWrapper", () => { Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "model + model ", }, "Hello!", @@ -67,7 +63,7 @@ describe("GemmaChatWrapper", () => { expect(contextText2.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "user ", }, @@ -78,35 +74,23 @@ describe("GemmaChatWrapper", () => { Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "model + model ", }, "Hello!", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "user + user ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "model + model ", }, "I'm good, how are you?", @@ -125,7 +109,7 @@ describe("GemmaChatWrapper", () => { expect(contextText3.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "user ", }, @@ -136,13 +120,9 @@ describe("GemmaChatWrapper", () => { Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "model + model ", }, "Hello!", @@ -152,7 +132,7 @@ describe("GemmaChatWrapper", () => { expect(contextText3WithOpenModelResponse.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "user ", }, @@ -163,13 +143,9 @@ describe("GemmaChatWrapper", () => { Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "model + model ", }, "Hello! diff --git a/test/standalone/chatWrappers/GeneralChatPromptWrapper.test.ts b/test/standalone/chatWrappers/GeneralChatPromptWrapper.test.ts index 6a38df4c..d4f7904b 100644 --- a/test/standalone/chatWrappers/GeneralChatPromptWrapper.test.ts +++ b/test/standalone/chatWrappers/GeneralChatPromptWrapper.test.ts @@ -38,24 +38,17 @@ describe("GeneralChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. - ", - "### Human - ", - "Hi there!", - " + ### Human + Hi there! - ", - "### Assistant - ", - "Hello!", + ### Assistant + Hello!", ] `); @@ -65,36 +58,23 @@ describe("GeneralChatWrapper", () => { expect(contextText2.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. - ", - "### Human - ", - "Hi there!", - " + ### Human + Hi there! - ", - "### Assistant - ", - "Hello!", - " + ### Assistant + Hello! - ", - "### Human - ", - "How are you?", - " + ### Human + How are you? - ", - "### Assistant - ", - "I'm good, how are you?", + ### Assistant + I'm good, how are you?", ] `); @@ -111,52 +91,36 @@ describe("GeneralChatWrapper", () => { expect(contextText3.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. - ", - "### Human - ", - "Hi there!", - " + ### Human + Hi there! - ", - "### Assistant - ", - "Hello!", + ### Assistant + Hello!", ] `); expect(contextText3WithOpenModelResponse.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. - ", - "### Human - ", - "Hi there!", - " + ### Human + Hi there! - ", - "### Assistant - ", - "Hello!", - " + ### Assistant + Hello! - ", - "### Assistant + ### Assistant ", ] `); @@ -172,24 +136,17 @@ describe("GeneralChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. - ", - "### Instruction - ", - "Hi there!", - " + ### Instruction + Hi there! - ", - "### Response - ", - "Hello!", + ### Response + Hello!", ] `); @@ -202,36 +159,23 @@ describe("GeneralChatWrapper", () => { expect(contextText2.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. - If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", - " + If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information. - ", - "### Instruction - ", - "Hi there!", - " + ### Instruction + Hi there! - ", - "### Response - ", - "Hello!", - " + ### Response + Hello! - ", - "### Instruction - ", - "How are you?", - " + ### Instruction + How are you? - ", - "### Response - ", - "I'm good, how are you?", + ### Response + I'm good, how are you?", ] `); }); diff --git a/test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts b/test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts index 7bb478ca..7f85a32b 100644 --- a/test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts +++ b/test/standalone/chatWrappers/LlamaChatPromptWrapper.test.ts @@ -38,23 +38,18 @@ describe("LlamaChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", - "value": "[INST] ", - }, - { - "type": "specialToken", - "value": "<> + "type": "specialTokensText", + "value": "[INST] <> ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> @@ -62,10 +57,8 @@ describe("LlamaChatWrapper", () => { }, "Hi there!", { - "type": "specialToken", - "value": " [/INST] - - ", + "type": "specialTokensText", + "value": " [/INST] ", }, "Hello!", ] @@ -77,23 +70,18 @@ describe("LlamaChatWrapper", () => { expect(contextText2.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", - "value": "[INST] ", - }, - { - "type": "specialToken", - "value": "<> + "type": "specialTokensText", + "value": "[INST] <> ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> @@ -101,32 +89,26 @@ describe("LlamaChatWrapper", () => { }, "Hi there!", { - "type": "specialToken", - "value": " [/INST] - - ", + "type": "specialTokensText", + "value": " [/INST] ", }, "Hello!", { - "builtin": true, "type": "specialToken", "value": "EOS", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "How are you?", { - "type": "specialToken", - "value": " [/INST] - - ", + "type": "specialTokensText", + "value": " [/INST] ", }, "I'm good, how are you?", ] @@ -145,23 +127,18 @@ describe("LlamaChatWrapper", () => { expect(contextText3.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", - "value": "[INST] ", - }, - { - "type": "specialToken", - "value": "<> + "type": "specialTokensText", + "value": "[INST] <> ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> @@ -169,10 +146,8 @@ describe("LlamaChatWrapper", () => { }, "Hi there!", { - "type": "specialToken", - "value": " [/INST] - - ", + "type": "specialTokensText", + "value": " [/INST] ", }, "Hello!", ] @@ -181,23 +156,18 @@ describe("LlamaChatWrapper", () => { expect(contextText3WithOpenModelResponse.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", - "value": "[INST] ", - }, - { - "type": "specialToken", - "value": "<> + "type": "specialTokensText", + "value": "[INST] <> ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> @@ -205,10 +175,8 @@ describe("LlamaChatWrapper", () => { }, "Hi there!", { - "type": "specialToken", - "value": " [/INST] - - ", + "type": "specialTokensText", + "value": " [/INST] ", }, "Hello! diff --git a/test/standalone/chatWrappers/generic/JinjaTemplateChatWrapper.test.ts b/test/standalone/chatWrappers/generic/JinjaTemplateChatWrapper.test.ts index 93d5e1c9..129d93a3 100644 --- a/test/standalone/chatWrappers/generic/JinjaTemplateChatWrapper.test.ts +++ b/test/standalone/chatWrappers/generic/JinjaTemplateChatWrapper.test.ts @@ -135,31 +135,30 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "<> ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> ", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello!", @@ -169,18 +168,16 @@ describe("JinjaTemplateChatWrapper", () => { [ LlamaText [ { - "builtin": true, "type": "specialToken", "value": "EOS", }, ], LlamaText [ { - "type": "specialToken", + "type": "specialTokensText", "value": " ", }, { - "builtin": true, "type": "specialToken", "value": "EOS", }, @@ -193,55 +190,52 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText2.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "<> ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> ", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello!", { - "type": "specialToken", + "type": "specialTokensText", "value": " ", }, { - "builtin": true, "type": "specialToken", "value": "EOS", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "I'm good, how are you?", @@ -260,31 +254,30 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText3.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "<> ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> ", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello!", @@ -294,31 +287,30 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText3WithOpenModelResponse.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "<> ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> ", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello! @@ -332,41 +324,38 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText4.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello!", { - "type": "specialToken", + "type": "specialTokensText", "value": " ", }, { - "builtin": true, "type": "specialToken", "value": "EOS", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST]", }, ] @@ -382,12 +371,11 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "System: You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. @@ -395,7 +383,7 @@ describe("JinjaTemplateChatWrapper", () => { Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST]", }, "Hello!", @@ -412,12 +400,11 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "System: You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. @@ -425,7 +412,7 @@ describe("JinjaTemplateChatWrapper", () => { Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello!", @@ -443,12 +430,11 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "System: You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. @@ -456,7 +442,7 @@ describe("JinjaTemplateChatWrapper", () => { Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello!", @@ -474,14 +460,14 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "<> ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> @@ -491,24 +477,23 @@ describe("JinjaTemplateChatWrapper", () => { "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> ", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello!", @@ -527,7 +512,7 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "<> ", }, @@ -551,24 +536,23 @@ describe("JinjaTemplateChatWrapper", () => { After calling a function the result will appear afterwards and be visible only to the assistant, so the assistant has to tell the user about it outside of the function call context. The assistant calls the functions in advance before telling the user about the result", { - "type": "specialToken", + "type": "specialTokensText", "value": " <> ", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello!", @@ -591,12 +575,11 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "System: The assistant calls the provided functions as needed to retrieve information instead of relying on things it already knows. @@ -618,32 +601,30 @@ describe("JinjaTemplateChatWrapper", () => { Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello! [[call: func2({"message":"Hello","feeling":"good","words":1})]] [[result: {"yes":true,"message":"ok"}]]", { - "type": "specialToken", + "type": "specialTokensText", "value": " ", }, { - "builtin": true, "type": "specialToken", "value": "EOS", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST]", }, ] @@ -665,12 +646,11 @@ describe("JinjaTemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "System: The assistant calls the provided functions as needed to retrieve information instead of relying on things it already knows. @@ -692,7 +672,7 @@ describe("JinjaTemplateChatWrapper", () => { Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST] ", }, "Hello! @@ -701,26 +681,24 @@ describe("JinjaTemplateChatWrapper", () => { Function result: {"yes":true,"message":"ok"} ", { - "type": "specialToken", + "type": "specialTokensText", "value": " ", }, { - "builtin": true, "type": "specialToken", "value": "EOS", }, { - "builtin": true, "type": "specialToken", "value": "BOS", }, { - "type": "specialToken", + "type": "specialTokensText", "value": "[INST] ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": " [/INST]", }, ] diff --git a/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts b/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts index 71306ac0..97942869 100644 --- a/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts +++ b/test/standalone/chatWrappers/generic/TemplateChatWrapper.test.ts @@ -106,19 +106,19 @@ describe("TemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "SYS: ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model:", }, @@ -131,35 +131,31 @@ describe("TemplateChatWrapper", () => { expect(contextText2.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "SYS: ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model: ", }, "Hello!", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "user: ", + user: ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": " model:", }, @@ -179,19 +175,19 @@ describe("TemplateChatWrapper", () => { expect(contextText3.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "SYS: ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model:", }, @@ -202,19 +198,19 @@ describe("TemplateChatWrapper", () => { expect(contextText3WithOpenModelResponse.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "SYS: ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model:", }, @@ -229,29 +225,25 @@ describe("TemplateChatWrapper", () => { expect(contextText4.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "SYS: user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model: ", }, "Hello!", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "user: ", + user: ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": " ", }, @@ -272,19 +264,19 @@ describe("TemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "BEGIN system: ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model:", }, @@ -306,19 +298,19 @@ describe("TemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "system: ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": " user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model:", }, @@ -342,7 +334,7 @@ describe("TemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "system: ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. @@ -365,13 +357,13 @@ describe("TemplateChatWrapper", () => { After calling a function the result will appear afterwards and be visible only to the assistant, so the assistant has to tell the user about it outside of the function call context. The assistant calls the functions in advance before telling the user about the result", { - "type": "specialToken", + "type": "specialTokensText", "value": " user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model:", }, @@ -399,7 +391,7 @@ describe("TemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "system: ", }, "The assistant calls the provided functions as needed to retrieve information instead of relying on things it already knows. @@ -419,30 +411,26 @@ describe("TemplateChatWrapper", () => { After calling a function the result will appear afterwards and be visible only to the assistant, so the assistant has to tell the user about it outside of the function call context. The assistant calls the functions in advance before telling the user about the result", { - "type": "specialToken", + "type": "specialTokensText", "value": " user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model: ", }, "Hello! [[call: func2({"message":"Hello","feeling":"good","words":1})]] [[result: {"yes":true,"message":"ok"}]]", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "user: ", + user: ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": " ", }, @@ -469,7 +457,7 @@ describe("TemplateChatWrapper", () => { expect(contextText.values).toMatchInlineSnapshot(` [ { - "type": "specialToken", + "type": "specialTokensText", "value": "system: ", }, "The assistant calls the provided functions as needed to retrieve information instead of relying on things it already knows. @@ -489,13 +477,13 @@ describe("TemplateChatWrapper", () => { After calling a function the result will appear afterwards and be visible only to the assistant, so the assistant has to tell the user about it outside of the function call context. The assistant calls the functions in advance before telling the user about the result", { - "type": "specialToken", + "type": "specialTokensText", "value": " user: ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": " model: ", }, @@ -505,17 +493,13 @@ describe("TemplateChatWrapper", () => { Function result: {"yes":true,"message":"ok"} ", { - "type": "specialToken", + "type": "specialTokensText", "value": " - ", - }, - { - "type": "specialToken", - "value": "user: ", + user: ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": " ", }, diff --git a/test/standalone/parseModelFileName.test.ts b/test/standalone/parseModelFileName.test.ts index 4e4b840b..1561aa0a 100644 --- a/test/standalone/parseModelFileName.test.ts +++ b/test/standalone/parseModelFileName.test.ts @@ -112,4 +112,28 @@ describe("parseModelFileName", () => { otherInfo: ["chat"] }); }); + + test("rpguild-chatml-13b.Q4_K_M.gguf", () => { + expect(parseModelFileName("rpguild-chatml-13b.Q4_K_M.gguf")) + .toEqual({ + name: "rpguild", + subType: "chatml", + quantization: "Q4_K_M", + fileType: "gguf", + parameters: "13B", + otherInfo: [] + }); + }); + + test("neuralbeagle14-7b.Q4_K_M.gguf", () => { + expect(parseModelFileName("neuralbeagle14-7b.Q4_K_M.gguf")) + .toEqual({ + name: "neuralbeagle14", + subType: "", + quantization: "Q4_K_M", + fileType: "gguf", + parameters: "7B", + otherInfo: [] + }); + }); }); From dc5dfca23eeacdca12a09bc7808cda361480288a Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 3 Apr 2024 01:06:54 +0300 Subject: [PATCH 43/52] test: fix tests --- .../ChatMLChatPromptWrapper.test.ts | 104 +++++++----------- 1 file changed, 38 insertions(+), 66 deletions(-) diff --git a/test/standalone/chatWrappers/ChatMLChatPromptWrapper.test.ts b/test/standalone/chatWrappers/ChatMLChatPromptWrapper.test.ts index 3aeb05b7..67bfa74a 100644 --- a/test/standalone/chatWrappers/ChatMLChatPromptWrapper.test.ts +++ b/test/standalone/chatWrappers/ChatMLChatPromptWrapper.test.ts @@ -39,30 +39,26 @@ describe("ChatMLChatWrapper", () => { [ { "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialTokensText", "value": "<|im_start|>system ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>user + <|im_start|>user ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>assistant + <|im_start|>assistant ", }, "Hello!", @@ -76,52 +72,40 @@ describe("ChatMLChatWrapper", () => { [ { "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialTokensText", "value": "<|im_start|>system ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>user + <|im_start|>user ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>assistant + <|im_start|>assistant ", }, "Hello!", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>user + <|im_start|>user ", }, "How are you?", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>assistant + <|im_start|>assistant ", }, "I'm good, how are you?", @@ -142,30 +126,26 @@ describe("ChatMLChatWrapper", () => { [ { "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialTokensText", "value": "<|im_start|>system ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>user + <|im_start|>user ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>assistant + <|im_start|>assistant ", }, "Hello!", @@ -176,41 +156,33 @@ describe("ChatMLChatWrapper", () => { [ { "type": "specialToken", + "value": "BOS", + }, + { + "type": "specialTokensText", "value": "<|im_start|>system ", }, "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible. If a question does not make any sense, or is not factually coherent, explain why instead of answering something incorrectly. If you don't know the answer to a question, don't share false information.", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>user + <|im_start|>user ", }, "Hi there!", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>assistant + <|im_start|>assistant ", }, "Hello!", { - "type": "specialToken", + "type": "specialTokensText", "value": "<|im_end|> - ", - }, - { - "type": "specialToken", - "value": "<|im_start|>assistant + <|im_start|>assistant ", }, ] From 547cfa832b218c708aa7ea3e8730f57eed4d4426 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 3 Apr 2024 01:23:57 +0300 Subject: [PATCH 44/52] fix: macOS build --- llama/gpuInfo/metal-gpu-info.h | 1 + llama/gpuInfo/metal-gpu-info.mm | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/llama/gpuInfo/metal-gpu-info.h b/llama/gpuInfo/metal-gpu-info.h index 07b0d56b..30056ce7 100644 --- a/llama/gpuInfo/metal-gpu-info.h +++ b/llama/gpuInfo/metal-gpu-info.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include void getMetalGpuInfo(uint64_t * total, uint64_t * used); diff --git a/llama/gpuInfo/metal-gpu-info.mm b/llama/gpuInfo/metal-gpu-info.mm index 61c7d510..7bfd6bce 100644 --- a/llama/gpuInfo/metal-gpu-info.mm +++ b/llama/gpuInfo/metal-gpu-info.mm @@ -1,5 +1,6 @@ #include #include +#include #import void getMetalGpuInfo(uint64_t * total, uint64_t * used) { @@ -19,11 +20,11 @@ void getMetalGpuInfo(uint64_t * total, uint64_t * used) { void getMetalGpuDeviceNames(std::vector * deviceNames) { NSArray> *devices = MTLCopyAllDevices(); - + for (id device in devices) { (*deviceNames).push_back(std::string(([NSString stringWithUTF8String:device.name.UTF8String]).UTF8String)); } - + [devices release]; devices = nil; } From 367e043620f645c257326aa5464e93fa26ff91c0 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 3 Apr 2024 02:03:46 +0300 Subject: [PATCH 45/52] fix: update `lifecycle-utils` to improve `splitText`'s runtime efficiency --- package-lock.json | 8 ++++---- package.json | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/package-lock.json b/package-lock.json index 6b62664d..b79095c2 100644 --- a/package-lock.json +++ b/package-lock.json @@ -20,7 +20,7 @@ "cross-spawn": "^7.0.3", "env-var": "^7.3.1", "fs-extra": "^11.2.0", - "lifecycle-utils": "^1.4.0", + "lifecycle-utils": "^1.4.1", "log-symbols": "^5.1.0", "node-addon-api": "^7.0.0", "octokit": "^3.1.0", @@ -7905,9 +7905,9 @@ } }, "node_modules/lifecycle-utils": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/lifecycle-utils/-/lifecycle-utils-1.4.0.tgz", - "integrity": "sha512-t2Fg6vKtUjPhhFFWFLeaJlQZrEzOo9bUPvSdKQK7hO/MM9Kfo2s1EoA5SfVtFKLA5yiO5myEq9DJ6qYlOypXUQ==" + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/lifecycle-utils/-/lifecycle-utils-1.4.1.tgz", + "integrity": "sha512-l8itA/+LnqlgMWM5AuSanjZk+S0+Ia9TldZPd9JHy4bCfrk1lUmNWKgt+xTuDqKy1sCI0dKZ7234R+wpVcBGUg==" }, "node_modules/lines-and-columns": { "version": "1.2.4", diff --git a/package.json b/package.json index 36eabb65..5f96b4c8 100644 --- a/package.json +++ b/package.json @@ -160,7 +160,7 @@ "cross-spawn": "^7.0.3", "env-var": "^7.3.1", "fs-extra": "^11.2.0", - "lifecycle-utils": "^1.4.0", + "lifecycle-utils": "^1.4.1", "log-symbols": "^5.1.0", "node-addon-api": "^7.0.0", "octokit": "^3.1.0", From ccae9fed8834b22cfb89df0e891306321c6bf697 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 3 Apr 2024 02:56:21 +0300 Subject: [PATCH 46/52] test: add sensible timeouts --- test/modelDependent/functionary/gguf/ggufInsights.test.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/modelDependent/functionary/gguf/ggufInsights.test.ts b/test/modelDependent/functionary/gguf/ggufInsights.test.ts index bc3787b5..d329ab42 100644 --- a/test/modelDependent/functionary/gguf/ggufInsights.test.ts +++ b/test/modelDependent/functionary/gguf/ggufInsights.test.ts @@ -74,7 +74,7 @@ describe("gguf", async () => { `); }); - test("predicted VRAM usage should match actual VRAM usage", async (testContext) => { + test("predicted VRAM usage should match actual VRAM usage", {timeout: 1000 * 60 * 5}, async (testContext) => { const llama = await getTestLlama(); const ggufMetadataParseResult = await readGgufFileInfo(modelPath); @@ -127,7 +127,7 @@ describe("gguf", async () => { await model.dispose(); }); - test("predicted VRAM usage should match actual VRAM usage when using gpuLayers", async (context) => { + test("predicted VRAM usage should match actual VRAM usage when using gpuLayers", {timeout: 1000 * 60 * 5}, async (context) => { const llama = await getTestLlama(); const ggufMetadataParseResult = await readGgufFileInfo(modelPath); From 6ec6829979270c812710b091c65c6546088dc2d2 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 3 Apr 2024 03:08:25 +0300 Subject: [PATCH 47/52] fix: vitest config --- vitest.config.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/vitest.config.ts b/vitest.config.ts index ff41ab8f..f90683cc 100644 --- a/vitest.config.ts +++ b/vitest.config.ts @@ -5,6 +5,7 @@ export default defineConfig({ pool: "forks", maxWorkers: 1, minWorkers: 1, + maxConcurrency: 1, poolOptions: { threads: { minThreads: 1, From 8f088762fabe61bd671a6b5a061d96078c365cfb Mon Sep 17 00:00:00 2001 From: Gilad S Date: Wed, 3 Apr 2024 21:52:20 +0300 Subject: [PATCH 48/52] feat: `inspect measure` command --- .vitepress/config.ts | 1 + docs/guide/cli/cli.data.ts | 4 + docs/guide/cli/inspect/measure.md | 17 + package-lock.json | 5 +- package.json | 1 + src/cli/commands/ChatCommand.ts | 5 +- src/cli/commands/CompleteCommand.ts | 5 +- src/cli/commands/InfillCommand.ts | 5 +- src/cli/commands/inspect/InspectCommand.ts | 4 +- .../inspect/commands/InspectGgufCommand.ts | 4 +- .../inspect/commands/InspectMeasureCommand.ts | 716 ++++++++++++++++++ src/cli/utils/ConsoleTable.ts | 132 ++++ src/cli/utils/printCommonInfoLines.ts | 12 +- src/cli/utils/resolveCommandGgufPath.ts | 6 + src/gguf/GgufInsights.ts | 8 +- 15 files changed, 909 insertions(+), 16 deletions(-) create mode 100644 docs/guide/cli/inspect/measure.md create mode 100644 src/cli/commands/inspect/commands/InspectMeasureCommand.ts create mode 100644 src/cli/utils/ConsoleTable.ts create mode 100644 src/cli/utils/resolveCommandGgufPath.ts diff --git a/.vitepress/config.ts b/.vitepress/config.ts index cd11d87b..f6739798 100644 --- a/.vitepress/config.ts +++ b/.vitepress/config.ts @@ -206,6 +206,7 @@ export default defineConfig({ items: [ {text: "GPU", link: "/inspect/gpu"}, {text: "GGUF", link: "/inspect/gguf"}, + {text: "Measure", link: "/inspect/measure"}, ] }, {text: "Build", link: "/build"}, diff --git a/docs/guide/cli/cli.data.ts b/docs/guide/cli/cli.data.ts index 9161f0cc..cd7274ae 100644 --- a/docs/guide/cli/cli.data.ts +++ b/docs/guide/cli/cli.data.ts @@ -14,6 +14,7 @@ import {cliBinName, npxRunPrefix} from "../../../src/config.js"; import {buildHtmlHeading} from "../../../.vitepress/utils/buildHtmlHeading.js"; import {buildHtmlTable} from "../../../.vitepress/utils/buildHtmlTable.js"; import {setIsInDocumentationMode} from "../../../src/state.js"; +import {InspectMeasureCommand} from "../../../src/cli/commands/inspect/commands/InspectMeasureCommand.js"; export default { async load() { @@ -42,6 +43,9 @@ export default { }), gguf: await getCommandHtmlDoc(InspectGgufCommand, { parentCommand: InspectCommand + }), + measure: await getCommandHtmlDoc(InspectMeasureCommand, { + parentCommand: InspectCommand }) }, download: await getCommandHtmlDoc(DownloadCommand), diff --git a/docs/guide/cli/inspect/measure.md b/docs/guide/cli/inspect/measure.md new file mode 100644 index 00000000..737a0067 --- /dev/null +++ b/docs/guide/cli/inspect/measure.md @@ -0,0 +1,17 @@ +--- +outline: deep +--- +# `inspect measure` command + + + +{{commandDoc.description}} + +## Usage +```shell-vue +{{commandDoc.usage}} +``` +
diff --git a/package-lock.json b/package-lock.json index b79095c2..c10403d8 100644 --- a/package-lock.json +++ b/package-lock.json @@ -28,6 +28,7 @@ "proper-lockfile": "^4.1.2", "semver": "^7.6.0", "simple-git": "^3.19.1", + "slice-ansi": "^7.1.0", "strip-ansi": "^7.1.0", "uuid": "^9.0.0", "which": "^4.0.0", @@ -6585,7 +6586,6 @@ "version": "1.2.0", "resolved": "https://registry.npmjs.org/get-east-asian-width/-/get-east-asian-width-1.2.0.tgz", "integrity": "sha512-2nk+7SIVb14QrgXFHcm84tD4bKQz0RxPuMT8Ag5KPOq7J5fEmAg0UbXdTOSHqNuHSU28k55qnceesxXRZGzKWA==", - "dev": true, "engines": { "node": ">=18" }, @@ -13419,7 +13419,6 @@ "version": "7.1.0", "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-7.1.0.tgz", "integrity": "sha512-bSiSngZ/jWeX93BqeIAbImyTbEihizcwNjFoRUIY/T1wWQsfsm2Vw1agPKylXvQTU7iASGdHhyqRlqQzfz+Htg==", - "dev": true, "dependencies": { "ansi-styles": "^6.2.1", "is-fullwidth-code-point": "^5.0.0" @@ -13435,7 +13434,6 @@ "version": "6.2.1", "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.1.tgz", "integrity": "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==", - "dev": true, "engines": { "node": ">=12" }, @@ -13447,7 +13445,6 @@ "version": "5.0.0", "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-5.0.0.tgz", "integrity": "sha512-OVa3u9kkBbw7b8Xw5F9P+D/T9X+Z4+JruYVNapTjPYZYUznQ5YfWeFkOj606XYYW8yugTfC8Pj0hYqvi4ryAhA==", - "dev": true, "dependencies": { "get-east-asian-width": "^1.0.0" }, diff --git a/package.json b/package.json index 5f96b4c8..6bf44e30 100644 --- a/package.json +++ b/package.json @@ -168,6 +168,7 @@ "proper-lockfile": "^4.1.2", "semver": "^7.6.0", "simple-git": "^3.19.1", + "slice-ansi": "^7.1.0", "strip-ansi": "^7.1.0", "uuid": "^9.0.0", "which": "^4.0.0", diff --git a/src/cli/commands/ChatCommand.ts b/src/cli/commands/ChatCommand.ts index 5225f72f..18c779e7 100644 --- a/src/cli/commands/ChatCommand.ts +++ b/src/cli/commands/ChatCommand.ts @@ -21,6 +21,7 @@ import { } from "../../chatWrappers/utils/resolveChatWrapper.js"; import {GeneralChatWrapper} from "../../chatWrappers/GeneralChatWrapper.js"; import {printCommonInfoLines} from "../utils/printCommonInfoLines.js"; +import {resolveCommandGgufPath} from "../utils/resolveCommandGgufPath.js"; type ChatCommand = { model: string, @@ -305,6 +306,8 @@ async function RunChat({ if (debug) console.info(`${chalk.yellow("Log level:")} debug`); + const resolvedModelPath = await resolveCommandGgufPath(modelArg); + const llamaLogLevel = debug ? LlamaLogLevel.debug : LlamaLogLevel.warn; @@ -344,7 +347,7 @@ async function RunChat({ }, async () => { try { return await llama.loadModel({ - modelPath: path.resolve(process.cwd(), modelArg), + modelPath: resolvedModelPath, gpuLayers: gpuLayers != null ? gpuLayers : undefined }); } finally { diff --git a/src/cli/commands/CompleteCommand.ts b/src/cli/commands/CompleteCommand.ts index da126f6d..ce7a8265 100644 --- a/src/cli/commands/CompleteCommand.ts +++ b/src/cli/commands/CompleteCommand.ts @@ -11,6 +11,7 @@ import withOra from "../../utils/withOra.js"; import {TokenMeter} from "../../evaluator/TokenMeter.js"; import {printInfoLine} from "../utils/printInfoLine.js"; import {printCommonInfoLines} from "../utils/printCommonInfoLines.js"; +import {resolveCommandGgufPath} from "../utils/resolveCommandGgufPath.js"; type CompleteCommand = { model: string, @@ -217,6 +218,8 @@ async function RunCompletion({ if (debug) console.info(`${chalk.yellow("Log level:")} debug`); + const resolvedModelPath = await resolveCommandGgufPath(modelArg); + const llamaLogLevel = debug ? LlamaLogLevel.debug : LlamaLogLevel.warn; @@ -249,7 +252,7 @@ async function RunCompletion({ }, async () => { try { return await llama.loadModel({ - modelPath: path.resolve(process.cwd(), modelArg), + modelPath: resolvedModelPath, gpuLayers: gpuLayers != null ? gpuLayers : undefined }); } finally { diff --git a/src/cli/commands/InfillCommand.ts b/src/cli/commands/InfillCommand.ts index f7bc18f1..c925d309 100644 --- a/src/cli/commands/InfillCommand.ts +++ b/src/cli/commands/InfillCommand.ts @@ -11,6 +11,7 @@ import withOra from "../../utils/withOra.js"; import {TokenMeter} from "../../evaluator/TokenMeter.js"; import {printInfoLine} from "../utils/printInfoLine.js"; import {printCommonInfoLines} from "../utils/printCommonInfoLines.js"; +import {resolveCommandGgufPath} from "../utils/resolveCommandGgufPath.js"; type InfillCommand = { model: string, @@ -229,6 +230,8 @@ async function RunInfill({ if (debug) console.info(`${chalk.yellow("Log level:")} debug`); + const resolvedModelPath = await resolveCommandGgufPath(modelArg); + const llamaLogLevel = debug ? LlamaLogLevel.debug : LlamaLogLevel.warn; @@ -275,7 +278,7 @@ async function RunInfill({ }, async () => { try { return await llama.loadModel({ - modelPath: path.resolve(process.cwd(), modelArg), + modelPath: resolvedModelPath, gpuLayers: gpuLayers != null ? gpuLayers : undefined }); } finally { diff --git a/src/cli/commands/inspect/InspectCommand.ts b/src/cli/commands/inspect/InspectCommand.ts index 504c4e2e..4275f70d 100644 --- a/src/cli/commands/inspect/InspectCommand.ts +++ b/src/cli/commands/inspect/InspectCommand.ts @@ -1,6 +1,7 @@ import {CommandModule} from "yargs"; import {InspectGgufCommand} from "./commands/InspectGgufCommand.js"; import {InspectGpuCommand} from "./commands/InspectGpuCommand.js"; +import {InspectMeasureCommand} from "./commands/InspectMeasureCommand.js"; type InspectCommand = { // no options for now @@ -12,7 +13,8 @@ export const InspectCommand: CommandModule = { builder(yargs) { return yargs .command(InspectGpuCommand) - .command(InspectGgufCommand); + .command(InspectGgufCommand) + .command(InspectMeasureCommand); }, async handler() { // this function must exit, even though we do nothing here diff --git a/src/cli/commands/inspect/commands/InspectGgufCommand.ts b/src/cli/commands/inspect/commands/InspectGgufCommand.ts index e003c800..1ccf9e2f 100644 --- a/src/cli/commands/inspect/commands/InspectGgufCommand.ts +++ b/src/cli/commands/inspect/commands/InspectGgufCommand.ts @@ -25,7 +25,7 @@ export const InspectGgufCommand: CommandModule = { .option("path", { type: "string", demandOption: true, - description: "The path to the GGUF file to inspect" + description: "The path or URL of the GGUF file to inspect. If a URL is provided, the metadata will be read from the remote file without downloading the entire file." }) .option("fullTensorInfo", { alias: "t", @@ -66,7 +66,7 @@ export const InspectGgufCommand: CommandModule = { console.info(`${chalk.yellow("File:")} ${resolvedGgufPath}`); } - const parsedMetadata = await readGgufFileInfo(ggufPath, {ignoreKeys: []}); + const parsedMetadata = await readGgufFileInfo(ggufPath); const fileTypeName = getGgufFileTypeName(parsedMetadata.metadata.general?.file_type); if (plainJson || outputToJsonFile != null) { diff --git a/src/cli/commands/inspect/commands/InspectMeasureCommand.ts b/src/cli/commands/inspect/commands/InspectMeasureCommand.ts new file mode 100644 index 00000000..95ef6463 --- /dev/null +++ b/src/cli/commands/inspect/commands/InspectMeasureCommand.ts @@ -0,0 +1,716 @@ +import path from "path"; +import process from "process"; +import {fileURLToPath} from "url"; +import {fork} from "node:child_process"; +import {CommandModule} from "yargs"; +import chalk from "chalk"; +import bytes from "bytes"; +import stripAnsi from "strip-ansi"; +import {readGgufFileInfo} from "../../../../gguf/readGgufFileInfo.js"; +import {resolveCommandGgufPath} from "../../../utils/resolveCommandGgufPath.js"; +import {getLlama} from "../../../../bindings/getLlama.js"; +import {LlamaLogLevel} from "../../../../bindings/types.js"; +import {LlamaModel} from "../../../../evaluator/LlamaModel.js"; +import {getConsoleLogPrefix} from "../../../../utils/getConsoleLogPrefix.js"; +import {ConsoleTable, ConsoleTableColumn} from "../../../utils/ConsoleTable.js"; +import {GgufInsights} from "../../../../gguf/GgufInsights.js"; + +type InspectMeasureCommand = { + path: string, + minLayers: number, + maxLayers?: number, + minContextSize: number, + maxContextSize?: number, + measures: number, + printHeaderBeforeEachLayer?: boolean +}; + +export const InspectMeasureCommand: CommandModule = { + command: "measure [path]", + describe: "Measure VRAM consumption of a GGUF model file with all possible combinations of gpu layers and context sizes", + builder(yargs) { + return yargs + .option("path", { + type: "string", + demandOption: true, + description: "The path of the GGUF model file to measure" + }) + .option("minLayers", { + alias: "mnl", + type: "number", + default: 1, + description: "Minimum number of layers to offload to the GPU", + group: "Optional:" + }) + .option("maxLayers", { + alias: "mxl", + type: "number", + default: -1, + defaultDescription: "All layers", + description: "Maximum number of layers to offload to the GPU", + group: "Optional:" + }) + .option("minContextSize", { + alias: "mncs", + type: "number", + default: 512, + description: "Minimum context size", + group: "Optional:" + }) + .option("maxContextSize", { + alias: "mxcs", + type: "number", + default: -1, + defaultDescription: "Train context size", + description: "Maximum context size", + group: "Optional:" + }) + .option("measures", { + alias: "m", + type: "number", + default: 10, + description: "Number of context size measures to take for each gpu layers count", + group: "Optional:" + }) + .option("printHeaderBeforeEachLayer", { + alias: "ph", + type: "boolean", + default: true, + description: "Print header before each layer's measures", + group: "Optional:" + }); + }, + async handler({ + path: ggufPath, minLayers, maxLayers, minContextSize, maxContextSize, measures = 10, printHeaderBeforeEachLayer = true + }: InspectMeasureCommand) { + if (maxLayers === -1) maxLayers = undefined; + if (maxContextSize === -1) maxContextSize = undefined; + if (minLayers < 1) minLayers = 1; + + const resolvedGgufPath = await resolveCommandGgufPath(ggufPath); + + // ensure a llama build is available + const llama = await getLlama("lastBuild", { + logLevel: LlamaLogLevel.error + }); + + console.info(`${chalk.yellow("File:")} ${resolvedGgufPath}`); + console.info(); + + const ggufMetadata = await readGgufFileInfo(resolvedGgufPath, { + sourceType: "filesystem" + }); + const ggufInsights = await GgufInsights.from(ggufMetadata, llama); + const totalVram = llama.getVramState().total; + + let lastGpuLayers = maxLayers ?? ggufInsights.totalLayers; + let previousContextSizeCheck: undefined | number = undefined; + + measureTable.logHeader({drawRowSeparator: !printHeaderBeforeEachLayer}); + + while (lastGpuLayers >= (minLayers ?? 0)) { + let printedAlreadyWithThisProcess = false; + let hadSuccessInThisProcess = false; + const getNewProccessValue = () => { + if (printedAlreadyWithThisProcess) + return undefined; + + printedAlreadyWithThisProcess = true; + return chalk.green("*"); + }; + + const done = await measureModel({ + modelPath: resolvedGgufPath, + maxGpuLayers: lastGpuLayers, + minGpuLayers: minLayers, + initialMaxContextSize: previousContextSizeCheck, + maxContextSize, + minContextSize, + tests: measures, + onInfo({gpuLayers, result}) { + if (lastGpuLayers !== gpuLayers) { + lastGpuLayers = gpuLayers; + previousContextSizeCheck = undefined; + measureTable.logLine({}); + + if (printHeaderBeforeEachLayer) + measureTable.logHeader({drawRowSeparator: false}); + } + + if (result.type === "crash") { + if (!hadSuccessInThisProcess) { + measureTable.logLine({ + newProcess: getNewProccessValue(), + type: chalk.redBright("Crash"), + gpuLayers: String(lastGpuLayers), + contextSize: previousContextSizeCheck != null + ? String(previousContextSizeCheck) + : chalk.red(result.result), + estimatedModelVram: previousContextSizeCheck == null + ? undefined + : chalk.red(result.result) + }); + lastGpuLayers--; + } + } else if (result.type === "error") { + previousContextSizeCheck = result.contextSize; + hadSuccessInThisProcess = true; + + measureTable.logLine({ + newProcess: getNewProccessValue(), + type: chalk.red("Error"), + gpuLayers: String(lastGpuLayers), + contextSize: previousContextSizeCheck != null + ? String(previousContextSizeCheck) + : chalk.red(result.error), + estimatedModelVram: previousContextSizeCheck == null + ? undefined + : chalk.red(result.error) + }); + } else if (result.type === "success") { + previousContextSizeCheck = result.contextSize; + hadSuccessInThisProcess = true; + + const modelVramEstimation = ggufInsights.estimateModelResourceRequirements({gpuLayers: lastGpuLayers}).gpuVram; + const modelVramEstimationDiffBytes = (result.modelVramUsage < modelVramEstimation ? "-" : "") + + bytes(Math.abs(result.modelVramUsage - modelVramEstimation)); + const modelVramEstimationDiffText = modelVramEstimationDiffBytes.padEnd(9, " ") + " " + + padStartAnsi("(" + renderDiffPercentageWithColors(((modelVramEstimation / result.modelVramUsage) - 1) * 100) + ")", 9); + + const contextVramEstimation = previousContextSizeCheck == null + ? undefined + : ggufInsights.estimateContextResourceRequirements({ + contextSize: previousContextSizeCheck, + modelGpuLayers: lastGpuLayers + }).gpuVram; + const contextVramEstimationDiffBytes = (result.contextVramUsage == null || contextVramEstimation == null) + ? undefined + : ( + (result.contextVramUsage < contextVramEstimation ? "-" : "") + + bytes(Math.abs(result.contextVramUsage - contextVramEstimation)) + ); + const contextVramEstimationDiffText = ( + contextVramEstimation == null || contextVramEstimationDiffBytes == null || result.contextVramUsage == null + ) + ? undefined + : ( + contextVramEstimationDiffBytes.padEnd(9, " ") + " " + + padStartAnsi("(" + renderDiffPercentageWithColors(((contextVramEstimation / result.contextVramUsage) - 1) * 100) + ")", 9) + ); + + measureTable.logLine({ + newProcess: getNewProccessValue(), + type: previousContextSizeCheck == null + ? "Model" + : "Context", + gpuLayers: String(lastGpuLayers), + contextSize: previousContextSizeCheck != null + ? String(previousContextSizeCheck) + : undefined, + + estimatedModelVram: bytes(modelVramEstimation), + actualModelVram: bytes(result.modelVramUsage), + modelEstimationDiff: modelVramEstimationDiffText, + + estimatedContextVram: contextVramEstimation == null + ? undefined + : bytes(contextVramEstimation), + actualContextVram: result.contextVramUsage == null + ? undefined + : bytes(result.contextVramUsage), + contextEstimationDiff: contextVramEstimationDiffText, + totalVramUsage: ((result.totalVramUsage / totalVram) * 100).toFixed(2).padStart(5, "0") + "% " + + chalk.grey("(" + bytes(result.totalVramUsage) + "/" + bytes(totalVram) + ")") + }); + } + } + }); + + if (done) + break; + } + } +}; + +const measureTable = new ConsoleTable([{ + key: "newProcess", + title: " ", + width: 1 +}, { + key: "type", + title: "Type", + width: Math.max("Type".length, "Model".length, "Context".length), + canSpanOverEmptyColumns: true +}, { + key: "gpuLayers", + title: "Layers", + width: "Layers".length, + canSpanOverEmptyColumns: true +}, { + key: "contextSize", + title: "Context size", + width: "Context size".length, + canSpanOverEmptyColumns: true +}, { + key: "estimatedModelVram", + title: "Estimated model VRAM", + width: "Estimated model VRAM".length, + canSpanOverEmptyColumns: true +}, { + key: "actualModelVram", + title: "Model VRAM", + width: "Model VRAM".length +}, { + key: "modelEstimationDiff", + title: "Diff", + width: Math.max("Diff".length, 9 + 1 + 9) +}, { + key: "estimatedContextVram", + title: "Estimated context VRAM", + width: "Estimated context VRAM".length +}, { + key: "actualContextVram", + title: "Context VRAM", + width: "Context VRAM".length +}, { + key: "contextEstimationDiff", + title: "Diff", + width: Math.max("Diff".length, 9 + 1 + 9) +}, { + key: "totalVramUsage", + title: "VRAM usage", + width: Math.max("VRAM usage".length, 8 + 1 + 8 + 1 + 8) +}] as const satisfies readonly ConsoleTableColumn[]); + +function renderDiffPercentageWithColors(percentage: number, { + greenBright = 2, + green = 6, + yellow = 10, + yellowBright = 14 +}: { + greenBright?: number, + green?: number, + yellow?: number, + yellowBright?: number +} = {}): string { + const percentageText = percentage.toFixed(2).padStart(5, "0") + "%"; + const absPercentage = Math.abs(percentage); + + if (absPercentage < greenBright) + return chalk.greenBright(percentageText); + else if (absPercentage < green) + return chalk.green(percentageText); + else if (absPercentage < yellow) + return chalk.yellow(percentageText); + else if (absPercentage < yellowBright) + return chalk.yellowBright(percentageText); + + return chalk.red(percentageText); +} + +const __filename = fileURLToPath(import.meta.url); +const detectedFileName = path.basename(__filename); +const expectedFileName = "InspectMeasureCommand"; + +async function measureModel({modelPath, tests, initialMaxContextSize, maxContextSize, minContextSize, maxGpuLayers, minGpuLayers, onInfo}: { + modelPath: string, + tests: number, + initialMaxContextSize?: number, + maxContextSize?: number, + minContextSize?: number, + maxGpuLayers: number, + minGpuLayers?: number, + onInfo(data: { + gpuLayers: number, + result: { + type: "error", + error: string, + contextSize?: number + } | { + type: "crash", + result: string + } | { + type: "success", + modelVramUsage: number, + contextSize?: number, + contextVramUsage?: number, + totalVramUsage: number + } + }): void +}) { + if (!detectedFileName.startsWith(expectedFileName)) { + console.warn( + getConsoleLogPrefix() + + `"${expectedFileName}.js" file is not independent, so running sub-process tests cannot be done with it\n` + + getConsoleLogPrefix() + + 'To resolve this issue, make sure that "node-llama-cpp" is not bundled together with other code.' + ); + + throw new Error("Sub-process tests cannot be done with the current file"); + } + + const subProcess = fork(__filename, [], { + detached: false, + stdio: [null, null, null, "ipc"], + env: { + ...process.env, + MEASURE_MODEL_CP: "true" + } + }); + let isPlannedExit = false; + let forkSucceeded = false; + let timeoutHandle: ReturnType | null = null; + const processCreationTimeout = 1000 * 60 * 5; + const stdTexts: string[] = []; + + let lastGpuLayers = maxGpuLayers; + + function cleanup() { + if (subProcess.exitCode == null) + subProcess.kill("SIGKILL"); + + if (timeoutHandle != null) + clearTimeout(timeoutHandle); + + process.off("exit", cleanup); + } + + process.on("exit", cleanup); + + subProcess.stdout?.on("data", (data) => { + stdTexts.push(data.toString()); + }); + subProcess.stderr?.on("data", (data) => { + stdTexts.push(data.toString()); + }); + + return Promise.race([ + new Promise((_, reject) => { + timeoutHandle = setTimeout(() => { + if (!forkSucceeded) { + reject(new Error("Measuring using a sub-process timed out")); + cleanup(); + } + }, processCreationTimeout); + }), + new Promise((resolve, reject) => { + function done() { + if (!forkSucceeded) + reject(new Error(`Measuring a model failed to run a sub-process via file "${__filename}"`)); + else + resolve(isPlannedExit); + + cleanup(); + } + + subProcess.on("message", (message: ChildToParentMessage) => { + if (message.type === "ready") { + forkSucceeded = true; + subProcess.send({ + type: "start", + modelPath, + tests, + initialMaxContextSize, + maxContextSize, + minContextSize, + maxGpuLayers, + minGpuLayers + } satisfies ParentToChildMessage); + + if (timeoutHandle != null) { + clearTimeout(timeoutHandle); + timeoutHandle = null; + } + } else if (message.type === "done") { + isPlannedExit = true; + subProcess.send({type: "exit"} satisfies ParentToChildMessage); + } else if (message.type === "error") { + lastGpuLayers = message.gpuLayers; + + onInfo({ + gpuLayers: lastGpuLayers, + result: { + type: "error", + error: message.error, + contextSize: message.contextSize + } + }); + } else if (message.type === "stats") { + lastGpuLayers = message.gpuLayers; + + onInfo({ + gpuLayers: message.gpuLayers, + result: { + type: "success", + modelVramUsage: message.modelVramUsage, + contextSize: message.contextSize, + contextVramUsage: message.contextVramUsage, + totalVramUsage: message.totalVramUsage + } + }); + } + }); + + subProcess.on("exit", (code) => { + if (code !== 0 || !isPlannedExit) + onInfo({ + gpuLayers: lastGpuLayers, + result: { + type: "crash", + result: stdTexts.join("") + } + }); + + done(); + }); + + if (subProcess.killed || subProcess.exitCode != null) { + if (subProcess.exitCode !== 0 || !isPlannedExit) + onInfo({ + gpuLayers: lastGpuLayers, + result: { + type: "crash", + result: stdTexts.join("") + } + }); + + done(); + } + }) + ]); +} + +if (process.env.MEASURE_MODEL_CP === "true" && process.send != null) { + void runTestWorkerLogic(); +} + +async function runTestWorkerLogic() { + const llama = await getLlama("lastBuild", { + logLevel: LlamaLogLevel.error + }); + + if (process.send == null) + throw new Error("No IPC channel to parent process"); + + function sendInfoBack(info: ChildToParentMessage) { + if (process.send == null) + process.exit(1); + + process.send(info); + } + + async function testContextSizes({model, modelVramUsage, startContextSize, maxContextSize, minContextSize, tests}: { + model: LlamaModel, modelVramUsage: number, startContextSize?: number, maxContextSize?: number, minContextSize?: number, + tests: number + }) { + const contextSizeCheckPlan = getContextSizesCheckPlan( + maxContextSize != null + ? Math.min(model.trainContextSize, maxContextSize) + : model.trainContextSize, + tests, + minContextSize + ); + + let currentContextSizeCheck = startContextSize == null + ? -1 + : getNextItemInCheckContextSizesPlan(contextSizeCheckPlan, startContextSize); + + while (currentContextSizeCheck != null) { + if (currentContextSizeCheck === -1) + currentContextSizeCheck = null; + + try { + const preContextVramUsage = llama.getVramState().used; + const context = await model.createContext({ + contextSize: currentContextSizeCheck ?? undefined + }); + const postContextVramUsage = llama.getVramState().used; + + sendInfoBack({ + type: "stats", + gpuLayers: model.gpuLayers, + modelVramUsage, + contextSize: context.contextSize, + contextVramUsage: postContextVramUsage - preContextVramUsage, + totalVramUsage: postContextVramUsage + }); + currentContextSizeCheck = context.contextSize; + + await context.dispose(); + } catch (err) { + sendInfoBack({ + type: "error", + error: String(err), + gpuLayers: model.gpuLayers, + contextSize: currentContextSizeCheck == null + ? undefined + : currentContextSizeCheck + }); + + if (currentContextSizeCheck == null) { + currentContextSizeCheck = contextSizeCheckPlan[contextSizeCheckPlan.length - 1]; + continue; + } + } + + currentContextSizeCheck = getNextItemInCheckContextSizesPlan(contextSizeCheckPlan, currentContextSizeCheck); + } + } + + async function testWithGpuLayers({modelPath, gpuLayers, tests, startContextSize, maxContextSize, minContextSize}: { + modelPath: string, gpuLayers: number, tests: number, startContextSize?: number, maxContextSize?: number, minContextSize?: number + }) { + try { + const preModelVramUsage = llama.getVramState().used; + const model = await llama.loadModel({ + modelPath, + gpuLayers + }); + const postModelVramUsage = llama.getVramState().used; + + sendInfoBack({ + type: "stats", + gpuLayers: model.gpuLayers, + modelVramUsage: postModelVramUsage - preModelVramUsage, + totalVramUsage: postModelVramUsage + }); + + await testContextSizes({ + model, + modelVramUsage: postModelVramUsage - preModelVramUsage, + startContextSize, + maxContextSize, + minContextSize, + tests + }); + + await model.dispose(); + } catch (err) { + sendInfoBack({ + type: "error", + error: String(err), + gpuLayers: gpuLayers + }); + } + } + + process.on("message", async (message: ParentToChildMessage) => { + if (message.type === "start") { + for (let gpuLayers = message.maxGpuLayers; gpuLayers >= (message.minGpuLayers ?? 0); gpuLayers--) { + await testWithGpuLayers({ + modelPath: message.modelPath, + gpuLayers, + tests: message.tests, + startContextSize: gpuLayers == message.maxGpuLayers + ? message.initialMaxContextSize + : undefined, + maxContextSize: message.maxContextSize, + minContextSize: message.minContextSize + }); + } + + sendInfoBack({type: "done"}); + } else if (message.type === "exit") { + await llama.dispose(); + process.exit(0); + } + }); + + process.send({type: "ready"} satisfies ChildToParentMessage); +} + +function getContextSizesCheckPlan(trainContextSize: number, tests: number = 10, minContextSize?: number) { + const res: number[] = []; + let shouldStop = false; + + const attemptToCoverSizes = [256, 512, 1024, 2048, 4096] as const; + + function addSize(size: number) { + if (size > trainContextSize) { + size = trainContextSize; + shouldStop = true; + } + + if (size < 2) + size = 2; + + if (res[res.length - 1] === size) { + shouldStop = true; + return; + } + + res.push(size); + } + + while (!shouldStop && res.length < tests) { + const lastSize = res[res.length - 1]; + + if (lastSize == null) { + addSize(Math.max(minContextSize ?? 0, Math.min(attemptToCoverSizes[0], trainContextSize / tests))); + continue; + } + + const stepSizesLeft = Math.floor( + (trainContextSize - Math.min(lastSize, attemptToCoverSizes[attemptToCoverSizes.length - 1])) / (tests - res.length) + ); + + let stopAddingAttemptedSizes = false; + for (const size of attemptToCoverSizes) { + if (stepSizesLeft > lastSize && lastSize < size && size <= trainContextSize) { + addSize(size); + stopAddingAttemptedSizes = true; + break; + } + } + if (stopAddingAttemptedSizes) + continue; + + addSize(lastSize + stepSizesLeft); + } + + return res.reverse(); +} + +function getNextItemInCheckContextSizesPlan(plan: number[], currentSize: number) { + for (const size of plan) { + if (size < currentSize) + return size; + } + + return null; +} + +type ParentToChildMessage = { + type: "start", + modelPath: string, + tests: number, + maxGpuLayers: number, + minGpuLayers?: number, + initialMaxContextSize?: number, + maxContextSize?: number, + minContextSize?: number +} | { + type: "exit" +}; + +type ChildToParentMessage = { + type: "ready" | "done" +} | { + type: "stats", + gpuLayers: number, + modelVramUsage: number, + contextSize?: number, + contextVramUsage?: number, + totalVramUsage: number +} | { + type: "error", + error: string, + gpuLayers: number, + contextSize?: number +}; + +function padStartAnsi(text: string, length: number, padChar: string = " ") { + const textWithoutAnsi = stripAnsi(text); + + return padChar.repeat(Math.max(0, length - textWithoutAnsi.length)) + text; +} diff --git a/src/cli/utils/ConsoleTable.ts b/src/cli/utils/ConsoleTable.ts new file mode 100644 index 00000000..0a6df686 --- /dev/null +++ b/src/cli/utils/ConsoleTable.ts @@ -0,0 +1,132 @@ +import chalk from "chalk"; +import sliceAnsi from "slice-ansi"; +import stripAnsi from "strip-ansi"; + +export class ConsoleTable { + private readonly _columns: T; + private readonly _columnSeparator: string; + private readonly _drawHeaderRowSeparator: boolean; + + public constructor(columns: T, { + columnSeparator = chalk.grey(" | "), + drawHeaderRowSeparator = true + }: { + columnSeparator?: string, + drawHeaderRowSeparator?: boolean + } = {}) { + this._columns = columns; + this._columnSeparator = columnSeparator; + this._drawHeaderRowSeparator = drawHeaderRowSeparator; + } + + public logHeader({drawRowSeparator = this._drawHeaderRowSeparator}: {drawRowSeparator?: boolean} = {}) { + let logLine = ""; + + for (let i = 0; i < this._columns.length; i++) { + const column = this._columns[i]; + const canSpanOverEmptyColumns = column.canSpanOverEmptyColumns ?? false; + let title = column.title ?? " "; + let columnSize = getColumnWidth(column); + + title = toOneLine(title); + + title = (column.titleFormatter ?? defaultTitleFormatter)(title); + + while (title.length > columnSize && canSpanOverEmptyColumns && i < this._columns.length - 1) { + i++; + const nextColumn = this._columns[i]; + + if (nextColumn.title != null) { + i--; + break; + } + + columnSize += stripAnsi(this._columnSeparator).length + getColumnWidth(nextColumn); + } + + const moreText = "..."; + if (stripAnsi(title).length > columnSize) + title = sliceAnsi(title, 0, columnSize - moreText.length) + chalk.grey(moreText); + + title = title + " ".repeat(Math.max(0, columnSize - stripAnsi(title).length)); + title = sliceAnsi(title, 0, columnSize); + + if (i < this._columns.length - 1) + title += this._columnSeparator; + + logLine += title; + } + + console.info(logLine); + + if (drawRowSeparator) + console.info(chalk.grey("-".repeat(stripAnsi(logLine).length))); + } + + public logLine(data: {[key in T[number]["key"]]?: string}) { + let logLine = ""; + + for (let i = 0; i < this._columns.length; i++) { + const column = this._columns[i]; + let value = data[column.key as keyof typeof data]; + const canSpanOverEmptyColumns = column.canSpanOverEmptyColumns ?? false; + + if (value != null && column.valueFormatter != null) + value = column.valueFormatter(value); + + if (value == null) + value = ""; + + value = toOneLine(value); + + const valueWithoutAnsi = stripAnsi(value); + let columnSize = getColumnWidth(column); + + while (valueWithoutAnsi.length > columnSize && canSpanOverEmptyColumns && i < this._columns.length - 1) { + i++; + const nextColumn = this._columns[i]; + const nextValue = data[nextColumn.key as keyof typeof data]; + + if (nextValue != null) { + i--; + break; + } + + columnSize += stripAnsi(this._columnSeparator).length + getColumnWidth(nextColumn); + } + + const moreText = "..."; + if (valueWithoutAnsi.length > columnSize) + value = sliceAnsi(value, 0, columnSize - moreText.length) + chalk.grey(moreText); + + value = value + " ".repeat(Math.max(0, columnSize - valueWithoutAnsi.length)); + value = sliceAnsi(value, 0, columnSize); + + if (i < this._columns.length - 1) + value += this._columnSeparator; + + logLine += value; + } + + console.info(logLine); + } +} + +const defaultTitleFormatter = (value: string) => chalk.bold(value); + +export type ConsoleTableColumn = { + readonly key: K, + readonly title?: string, + readonly titleFormatter?: (value: string) => string, + readonly width?: number, + readonly valueFormatter?: (value: string) => string, + readonly canSpanOverEmptyColumns?: boolean +}; + +function getColumnWidth(column: ConsoleTableColumn) { + return column.width ?? stripAnsi(toOneLine(column.title ?? " ")).length; +} + +function toOneLine(text: string) { + return text.replaceAll("\n", chalk.grey("\\n")); +} diff --git a/src/cli/utils/printCommonInfoLines.ts b/src/cli/utils/printCommonInfoLines.ts index a6e9b54e..96ff0ea4 100644 --- a/src/cli/utils/printCommonInfoLines.ts +++ b/src/cli/utils/printCommonInfoLines.ts @@ -35,7 +35,7 @@ export function printCommonInfoLines({ value: bytes(llama.getVramState().total) }, { title: "Name", - value: llama.getGpuDeviceNames().join(", ") + value: toOneLine(llama.getGpuDeviceNames().join(", ")) }, { title: "GPU layers", value: `${model.gpuLayers}/${model.fileInsights.totalLayers} offloaded ${ @@ -49,18 +49,18 @@ export function printCommonInfoLines({ padTitle: padTitle, info: [{ title: "Type", - value: model.typeDescription + value: toOneLine(model.typeDescription) }, { title: "Size", value: bytes(model.size) }, { show: printBos, title: "BOS", - value: () => String(model.tokens.bosString) + value: () => toOneLine(String(model.tokens.bosString)) }, { show: printEos, title: "EOS", - value: () => String(model.tokens.eosString) + value: () => toOneLine(String(model.tokens.eosString)) }, { title: "Train context size", value: String(model.trainContextSize) @@ -83,3 +83,7 @@ export function printCommonInfoLines({ }] }); } + +function toOneLine(text: string) { + return text.replaceAll("\n", chalk.grey("\\n")); +} diff --git a/src/cli/utils/resolveCommandGgufPath.ts b/src/cli/utils/resolveCommandGgufPath.ts new file mode 100644 index 00000000..a7bf8717 --- /dev/null +++ b/src/cli/utils/resolveCommandGgufPath.ts @@ -0,0 +1,6 @@ +import path from "path"; +import process from "process"; + +export async function resolveCommandGgufPath(modelPath: string) { + return path.resolve(process.cwd(), modelPath); +} diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts index 17d74dbc..41c634c5 100644 --- a/src/gguf/GgufInsights.ts +++ b/src/gguf/GgufInsights.ts @@ -1,5 +1,6 @@ import {Llama} from "../bindings/Llama.js"; import {getLlamaWithoutBackend} from "../bindings/utils/getLlamaWithoutBackend.js"; +import {getDefaultContextBatchSize, getDefaultContextSequences} from "../evaluator/LlamaContext/LlamaContext.js"; import {GgufFileInfo} from "./types/GgufFileInfoTypes.js"; import {GgufTensorInfo} from "./types/GgufTensorInfoTypes.js"; import {GgufArchitectureType} from "./types/GgufMetadataTypes.js"; @@ -61,11 +62,14 @@ export class GgufInsights { * The estimation for the graph overhead memory will be improved in the future to be more precise, but it's good enough for now. */ public estimateContextResourceRequirements({ - contextSize, batchSize, modelGpuLayers, sequences, isEmbeddingContext = false, includeGraphOverhead = true + contextSize, modelGpuLayers, batchSize, sequences, isEmbeddingContext = false, includeGraphOverhead = true }: { - contextSize: number, batchSize: number, modelGpuLayers: number, sequences: number, isEmbeddingContext?: boolean, + contextSize: number, modelGpuLayers: number, batchSize?: number, sequences?: number, isEmbeddingContext?: boolean, includeGraphOverhead?: boolean }): GgufInsightsResourceRequirements { + if (sequences == null) sequences = getDefaultContextSequences(); + if (batchSize == null) batchSize = getDefaultContextBatchSize({contextSize, sequences}); + const actualContextSize = contextSize * sequences; const totalLayers = this.totalLayers; From ffb0eabeb223b3d3b98bcbb4f494acad5ac0bcef Mon Sep 17 00:00:00 2001 From: Gilad S Date: Thu, 4 Apr 2024 20:19:40 +0300 Subject: [PATCH 49/52] feat: improve VRAM consumption estimations --- llama/addon.cpp | 1 + src/bindings/AddonTypes.ts | 1 + .../inspect/commands/InspectMeasureCommand.ts | 63 +++++++++++--- src/gguf/GgufInsights.ts | 86 +++++++++++++++++-- 4 files changed, 133 insertions(+), 18 deletions(-) diff --git a/llama/addon.cpp b/llama/addon.cpp index 4651a6fe..d30fa817 100644 --- a/llama/addon.cpp +++ b/llama/addon.cpp @@ -1541,6 +1541,7 @@ Napi::Value addonGetConsts(const Napi::CallbackInfo& info) { consts.Set("ggmlMaxDims", Napi::Number::New(info.Env(), GGML_MAX_DIMS)); consts.Set("ggmlTypeF16Size", Napi::Number::New(info.Env(), ggml_type_size(GGML_TYPE_F16))); consts.Set("ggmlTypeF32Size", Napi::Number::New(info.Env(), ggml_type_size(GGML_TYPE_F32))); + consts.Set("ggmlTensorOverhead", Napi::Number::New(info.Env(), ggml_tensor_overhead())); consts.Set("llamaMaxRngState", Napi::Number::New(info.Env(), LLAMA_MAX_RNG_STATE)); consts.Set("llamaPosSize", Napi::Number::New(info.Env(), sizeof(llama_pos))); consts.Set("llamaSeqIdSize", Napi::Number::New(info.Env(), sizeof(llama_seq_id))); diff --git a/src/bindings/AddonTypes.ts b/src/bindings/AddonTypes.ts index a7d7c075..37dcd71a 100644 --- a/src/bindings/AddonTypes.ts +++ b/src/bindings/AddonTypes.ts @@ -43,6 +43,7 @@ export type BindingModule = { ggmlMaxDims: number, ggmlTypeF16Size: number, ggmlTypeF32Size: number, + ggmlTensorOverhead: number, llamaMaxRngState: number, llamaPosSize: number, llamaSeqIdSize: number diff --git a/src/cli/commands/inspect/commands/InspectMeasureCommand.ts b/src/cli/commands/inspect/commands/InspectMeasureCommand.ts index 95ef6463..49485fe9 100644 --- a/src/cli/commands/inspect/commands/InspectMeasureCommand.ts +++ b/src/cli/commands/inspect/commands/InspectMeasureCommand.ts @@ -22,7 +22,9 @@ type InspectMeasureCommand = { minContextSize: number, maxContextSize?: number, measures: number, - printHeaderBeforeEachLayer?: boolean + printHeaderBeforeEachLayer?: boolean, + evaluateText?: string, + repeatEvaluateText?: number }; export const InspectMeasureCommand: CommandModule = { @@ -78,10 +80,24 @@ export const InspectMeasureCommand: CommandModule default: true, description: "Print header before each layer's measures", group: "Optional:" + }) + .option("evaluateText", { + alias: ["evaluate", "et"], + type: "string", + description: "Text to evaluate with the model", + group: "Optional:" + }) + .option("repeatEvaluateText", { + alias: ["repeatEvaluate", "ret"], + type: "number", + default: 1, + description: "Number of times to repeat the evaluation text before sending it for evaluation, in order to make it longer", + group: "Optional:" }); }, async handler({ - path: ggufPath, minLayers, maxLayers, minContextSize, maxContextSize, measures = 10, printHeaderBeforeEachLayer = true + path: ggufPath, minLayers, maxLayers, minContextSize, maxContextSize, measures = 10, printHeaderBeforeEachLayer = true, + evaluateText, repeatEvaluateText }: InspectMeasureCommand) { if (maxLayers === -1) maxLayers = undefined; if (maxContextSize === -1) maxContextSize = undefined; @@ -127,6 +143,9 @@ export const InspectMeasureCommand: CommandModule maxContextSize, minContextSize, tests: measures, + evaluateText: evaluateText == null + ? undefined + : evaluateText.repeat(repeatEvaluateText ?? 1), onInfo({gpuLayers, result}) { if (lastGpuLayers !== gpuLayers) { lastGpuLayers = gpuLayers; @@ -172,7 +191,7 @@ export const InspectMeasureCommand: CommandModule hadSuccessInThisProcess = true; const modelVramEstimation = ggufInsights.estimateModelResourceRequirements({gpuLayers: lastGpuLayers}).gpuVram; - const modelVramEstimationDiffBytes = (result.modelVramUsage < modelVramEstimation ? "-" : "") + + const modelVramEstimationDiffBytes = (modelVramEstimation < result.modelVramUsage ? "-" : "") + bytes(Math.abs(result.modelVramUsage - modelVramEstimation)); const modelVramEstimationDiffText = modelVramEstimationDiffBytes.padEnd(9, " ") + " " + padStartAnsi("(" + renderDiffPercentageWithColors(((modelVramEstimation / result.modelVramUsage) - 1) * 100) + ")", 9); @@ -186,7 +205,7 @@ export const InspectMeasureCommand: CommandModule const contextVramEstimationDiffBytes = (result.contextVramUsage == null || contextVramEstimation == null) ? undefined : ( - (result.contextVramUsage < contextVramEstimation ? "-" : "") + + (contextVramEstimation < result.contextVramUsage ? "-" : "") + bytes(Math.abs(result.contextVramUsage - contextVramEstimation)) ); const contextVramEstimationDiffText = ( @@ -312,7 +331,9 @@ const __filename = fileURLToPath(import.meta.url); const detectedFileName = path.basename(__filename); const expectedFileName = "InspectMeasureCommand"; -async function measureModel({modelPath, tests, initialMaxContextSize, maxContextSize, minContextSize, maxGpuLayers, minGpuLayers, onInfo}: { +async function measureModel({ + modelPath, tests, initialMaxContextSize, maxContextSize, minContextSize, maxGpuLayers, minGpuLayers, evaluateText, onInfo +}: { modelPath: string, tests: number, initialMaxContextSize?: number, @@ -320,6 +341,7 @@ async function measureModel({modelPath, tests, initialMaxContextSize, maxContext minContextSize?: number, maxGpuLayers: number, minGpuLayers?: number, + evaluateText?: string, onInfo(data: { gpuLayers: number, result: { @@ -334,6 +356,7 @@ async function measureModel({modelPath, tests, initialMaxContextSize, maxContext modelVramUsage: number, contextSize?: number, contextVramUsage?: number, + contextStateSize?: number, totalVramUsage: number } }): void @@ -414,7 +437,8 @@ async function measureModel({modelPath, tests, initialMaxContextSize, maxContext maxContextSize, minContextSize, maxGpuLayers, - minGpuLayers + minGpuLayers, + evaluateText } satisfies ParentToChildMessage); if (timeoutHandle != null) { @@ -445,6 +469,7 @@ async function measureModel({modelPath, tests, initialMaxContextSize, maxContext modelVramUsage: message.modelVramUsage, contextSize: message.contextSize, contextVramUsage: message.contextVramUsage, + contextStateSize: message.contextStateSize, totalVramUsage: message.totalVramUsage } }); @@ -499,9 +524,9 @@ async function runTestWorkerLogic() { process.send(info); } - async function testContextSizes({model, modelVramUsage, startContextSize, maxContextSize, minContextSize, tests}: { + async function testContextSizes({model, modelVramUsage, startContextSize, maxContextSize, minContextSize, tests, evaluateText}: { model: LlamaModel, modelVramUsage: number, startContextSize?: number, maxContextSize?: number, minContextSize?: number, - tests: number + tests: number, evaluateText?: string }) { const contextSizeCheckPlan = getContextSizesCheckPlan( maxContextSize != null @@ -524,6 +549,12 @@ async function runTestWorkerLogic() { const context = await model.createContext({ contextSize: currentContextSizeCheck ?? undefined }); + + if (evaluateText != null && evaluateText != "") { + const sequence = context.getSequence(); + await sequence.evaluateWithoutGeneratingNewTokens(model.tokenize(evaluateText)); + } + const postContextVramUsage = llama.getVramState().used; sendInfoBack({ @@ -532,6 +563,7 @@ async function runTestWorkerLogic() { modelVramUsage, contextSize: context.contextSize, contextVramUsage: postContextVramUsage - preContextVramUsage, + contextStateSize: context.stateSize, totalVramUsage: postContextVramUsage }); currentContextSizeCheck = context.contextSize; @@ -557,8 +589,9 @@ async function runTestWorkerLogic() { } } - async function testWithGpuLayers({modelPath, gpuLayers, tests, startContextSize, maxContextSize, minContextSize}: { - modelPath: string, gpuLayers: number, tests: number, startContextSize?: number, maxContextSize?: number, minContextSize?: number + async function testWithGpuLayers({modelPath, gpuLayers, tests, startContextSize, maxContextSize, minContextSize, evaluateText}: { + modelPath: string, gpuLayers: number, tests: number, startContextSize?: number, maxContextSize?: number, minContextSize?: number, + evaluateText?: string }) { try { const preModelVramUsage = llama.getVramState().used; @@ -581,7 +614,8 @@ async function runTestWorkerLogic() { startContextSize, maxContextSize, minContextSize, - tests + tests, + evaluateText }); await model.dispose(); @@ -605,7 +639,8 @@ async function runTestWorkerLogic() { ? message.initialMaxContextSize : undefined, maxContextSize: message.maxContextSize, - minContextSize: message.minContextSize + minContextSize: message.minContextSize, + evaluateText: message.evaluateText }); } @@ -688,7 +723,8 @@ type ParentToChildMessage = { minGpuLayers?: number, initialMaxContextSize?: number, maxContextSize?: number, - minContextSize?: number + minContextSize?: number, + evaluateText?: string } | { type: "exit" }; @@ -701,6 +737,7 @@ type ChildToParentMessage = { modelVramUsage: number, contextSize?: number, contextVramUsage?: number, + contextStateSize?: number, totalVramUsage: number } | { type: "error", diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts index 41c634c5..5dfc4b81 100644 --- a/src/gguf/GgufInsights.ts +++ b/src/gguf/GgufInsights.ts @@ -86,10 +86,13 @@ export class GgufInsights { const sizeTBytes = 8; // sizeof(size_t) const floatBytes = 4; // sizeof(float) const uint32TBytes = 4; // sizeof(uint32_t) + const int32TBytes = 4; // sizeof(int32_t) // source: `llama_get_state_size` in `llama.cpp` const sRngSize = sizeTBytes; const sRng = this._llama._consts.llamaMaxRngState; + const sNOutputs = sizeTBytes; + const sNOutputPos = batchSize * int32TBytes; const sLogitsSize = sizeTBytes; const sLogits = logitsSize * floatBytes; const sEmbeddingSize = sizeTBytes; @@ -98,7 +101,7 @@ export class GgufInsights { const sKvHead = uint32TBytes; const sKvSize = uint32TBytes; const sKvUsed = uint32TBytes; - // const sKv = this._estimateKvByteSize(contextSize); + const sKv = 2 * int32TBytes * modelGpuLayers * this._llama._consts.ggmlTensorOverhead; const sKvCell = this._llama._consts.llamaPosSize + sizeTBytes + this._llama._consts.llamaSeqIdSize; const kvSelfLength = this.ggufFileInfo.metadata.general.architecture === GgufArchitectureType.mamba ? Math.max(1, sequences) @@ -108,6 +111,8 @@ export class GgufInsights { const overheadMemory = ( sRngSize + sRng + + sNOutputs + + sNOutputPos + sLogitsSize + sLogits + sEmbeddingSize + @@ -116,19 +121,77 @@ export class GgufInsights { sKvHead + sKvSize + sKvUsed + + sKv + sKvCells ); // Estimates the memory allocated by `ggml_backend_sched_reserve` in `llama_new_context_with_model` in `llama.cpp`. // If you read this line and have better insights on how to estimate this memory, please open a PR to improve it :) const estimateGraphOverheadMemory = () => { + const s1MB = Math.pow(1024, 2); const tensorInfo = this.ggufFileInfo.tensorInfo ?? []; - const totalDimensions = tensorInfo.length === 0 + let defaultCalculationAdjustment = 0; + + if (batchSize == null) + return 0; + + if (this.ggufFileInfo.metadata.general.architecture === GgufArchitectureType.llama) { + const expertCount = this.ggufFileInfo.architectureMetadata.expert_count ?? 0; + const headCount = this.ggufFileInfo.architectureMetadata.attention?.head_count ?? 0; + const embeddingLength = llmData.embedding_length ?? 0; + + if (expertCount > 0) { + const expertsUsedCount = this.ggufFileInfo.architectureMetadata.expert_used_count ?? 2; + + return int32TBytes * batchSize * (((expertsUsedCount + 1) * embeddingLength) + (actualContextSize * headCount)); + } + + return int32TBytes * batchSize * (embeddingLength + (actualContextSize * headCount)); + } else if (this.ggufFileInfo.metadata.general.architecture === GgufArchitectureType.qwen2) { + if (modelGpuLayers === this.totalLayers) { + defaultCalculationAdjustment -= (s1MB * 340) * ( + this.trainContextSize == null + ? 1 + : actualContextSize / this.trainContextSize + ); + } else { + defaultCalculationAdjustment -= (s1MB * 250) + ( + (s1MB * 50) * ( + this.trainContextSize == null + ? 1 + : actualContextSize / this.trainContextSize + ) + ); + } + } else if (this.ggufFileInfo.metadata.general.architecture === GgufArchitectureType.gemma) { + // only works properly when all layers are on the GPU, which is why it's commented out: + // return int32TBytes * batchSize * ((llmData.embedding_length ?? 0)); + + if (modelGpuLayers === this.totalLayers) { + defaultCalculationAdjustment += (s1MB * 40) - ( + (s1MB * 270) * ( + this.trainContextSize == null + ? 1 + : actualContextSize / this.trainContextSize + ) + ); + } else { + defaultCalculationAdjustment += -(s1MB * 550) + ( + (s1MB * 150) * ( + this.trainContextSize == null + ? 1 + : Math.max(0, (1 - (actualContextSize / this.trainContextSize))) + ) + ); + } + } + + const totalElements = tensorInfo.length === 0 ? this.totalLayers * ( ( - (this.ggufFileInfo.architectureMetadata.embedding_length ?? 0) + - (this.ggufFileInfo.architectureMetadata.feed_forward_length ?? 0) + (llmData.embedding_length ?? 0) + + (llmData.feed_forward_length ?? 0) ) / 2 ) : tensorInfo.reduce((res, tensor) => { @@ -136,7 +199,7 @@ export class GgufInsights { }, 0); // magic numbers for estimation. will be improved in the future - return totalDimensions * 77.655 * (actualContextSize / 4096); + return (totalElements * 77.655 * (actualContextSize / 4096)) + defaultCalculationAdjustment; }; const graphOverheadMemory = !includeGraphOverhead @@ -193,6 +256,19 @@ export class GgufInsights { for (const singleTensorInfo of tensorInfo) { const {layerNumber} = parseTensorName(singleTensorInfo.name); + if (gpuLayers !== this.totalLayers) { + const architecture = this.ggufFileInfo.metadata?.general?.architecture; + + if (architecture === GgufArchitectureType.qwen2 || architecture === GgufArchitectureType.gemma) { + if (layerNumber != null && layerNumber < gpuLayers) + gpuTensors.push(singleTensorInfo); + else + cpuTensors.push(singleTensorInfo); + + continue; + } + } + if (layerNumber == null || layerNumber < gpuLayers) gpuTensors.push(singleTensorInfo); else From f2e52d30f8d97ac5f41ab744a7a5a891336fd787 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Thu, 4 Apr 2024 20:39:02 +0300 Subject: [PATCH 50/52] test: update tests --- .../functionaryModelGpuLayersOptions.test.ts | 70 +++++++++---------- .../functionary/gguf/ggufInsights.test.ts | 42 +++++------ .../stableCode/parallel.test.ts | 29 ++++---- .../stableCodeModelGpuLayersOptions.test.ts | 16 ++--- 4 files changed, 79 insertions(+), 78 deletions(-) diff --git a/test/modelDependent/functionary/functionaryModelGpuLayersOptions.test.ts b/test/modelDependent/functionary/functionaryModelGpuLayersOptions.test.ts index b7af73f0..554d60f2 100644 --- a/test/modelDependent/functionary/functionaryModelGpuLayersOptions.test.ts +++ b/test/modelDependent/functionary/functionaryModelGpuLayersOptions.test.ts @@ -104,7 +104,7 @@ describe("functionary", () => { freeVram: s1GB * 3 }); expect(res.gpuLayers).to.eql(16); - expect(res.contextSize).to.toMatchInlineSnapshot("8037"); + expect(res.contextSize).to.toMatchInlineSnapshot("7415"); } try { resolveGpuLayers(16, { @@ -167,7 +167,7 @@ describe("functionary", () => { freeVram: s1GB * 6 }); expect(res.gpuLayers).to.eql(32); - expect(res.contextSize).to.toMatchInlineSnapshot("11859"); + expect(res.contextSize).to.toMatchInlineSnapshot("11260"); } try { resolveGpuLayers(32, { @@ -216,7 +216,7 @@ describe("functionary", () => { freeVram: s1GB * 6 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("11859"); + expect(res.contextSize).to.toMatchInlineSnapshot("11260"); } try { resolveGpuLayers(33, { @@ -303,7 +303,7 @@ describe("functionary", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("633"); + expect(res.contextSize).to.toMatchInlineSnapshot("561"); } { const res = resolveGpuLayers("max", { @@ -311,7 +311,7 @@ describe("functionary", () => { freeVram: s1GB * 4.4 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("2878"); + expect(res.contextSize).to.toMatchInlineSnapshot("2701"); } { const res = resolveGpuLayers("max", { @@ -319,7 +319,7 @@ describe("functionary", () => { freeVram: s1GB * 4.8 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("5123"); + expect(res.contextSize).to.toMatchInlineSnapshot("4840"); } }); @@ -346,15 +346,15 @@ describe("functionary", () => { freeVram: s1GB * 0.8 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("1"); - expect(res.contextSize).to.toMatchInlineSnapshot("7608"); + expect(res.contextSize).to.toMatchInlineSnapshot("6522"); } { const res = resolveGpuLayers("auto", { totalVram: s1GB * 6, freeVram: s1GB * 1.4 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("5"); - expect(res.contextSize).to.toMatchInlineSnapshot("7964"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("4"); + expect(res.contextSize).to.toMatchInlineSnapshot("8799"); } { const res = resolveGpuLayers("auto", { @@ -362,7 +362,7 @@ describe("functionary", () => { freeVram: s1GB * 2.4 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("11"); - expect(res.contextSize).to.toMatchInlineSnapshot("9310"); + expect(res.contextSize).to.toMatchInlineSnapshot("8472"); } { const res = resolveGpuLayers("auto", { @@ -370,31 +370,31 @@ describe("functionary", () => { freeVram: s1GB * 3.1 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("16"); - expect(res.contextSize).to.toMatchInlineSnapshot("8891"); + expect(res.contextSize).to.toMatchInlineSnapshot("8209"); } { const res = resolveGpuLayers("auto", { totalVram: s1GB * 6, freeVram: s1GB * 3.3 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("18"); - expect(res.contextSize).to.toMatchInlineSnapshot("8118"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("17"); + expect(res.contextSize).to.toMatchInlineSnapshot("8628"); } { const res = resolveGpuLayers("auto", { totalVram: s1GB * 6, freeVram: s1GB * 3.5 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("19"); - expect(res.contextSize).to.toMatchInlineSnapshot("8544"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("18"); + expect(res.contextSize).to.toMatchInlineSnapshot("9024"); } { const res = resolveGpuLayers("auto", { totalVram: s1GB * 6, freeVram: s1GB * 3.8 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("21"); - expect(res.contextSize).to.toMatchInlineSnapshot("8590"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("20"); + expect(res.contextSize).to.toMatchInlineSnapshot("9042"); } { const res = resolveGpuLayers("auto", { @@ -402,7 +402,7 @@ describe("functionary", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("22"); - expect(res.contextSize).to.toMatchInlineSnapshot("8968"); + expect(res.contextSize).to.toMatchInlineSnapshot("8386"); } { const res = resolveGpuLayers("auto", { @@ -410,7 +410,7 @@ describe("functionary", () => { freeVram: s1GB * 4.3 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("24"); - expect(res.contextSize).to.toMatchInlineSnapshot("8988"); + expect(res.contextSize).to.toMatchInlineSnapshot("8434"); } { const res = resolveGpuLayers("auto", { @@ -418,7 +418,7 @@ describe("functionary", () => { freeVram: s1GB * 4.5 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("3439"); + expect(res.contextSize).to.toMatchInlineSnapshot("3235"); } { const res = resolveGpuLayers("auto", { @@ -426,7 +426,7 @@ describe("functionary", () => { freeVram: s1GB * 4.8 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("5123"); + expect(res.contextSize).to.toMatchInlineSnapshot("4840"); } { const res = resolveGpuLayers("auto", { @@ -434,7 +434,7 @@ describe("functionary", () => { freeVram: s1GB * 5.2 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("7368"); + expect(res.contextSize).to.toMatchInlineSnapshot("6980"); } { const res = resolveGpuLayers("auto", { @@ -442,7 +442,7 @@ describe("functionary", () => { freeVram: s1GB * 5.8 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("10736"); + expect(res.contextSize).to.toMatchInlineSnapshot("10190"); } { const res = resolveGpuLayers("auto", { @@ -450,7 +450,7 @@ describe("functionary", () => { freeVram: s1GB * 6 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("11859"); + expect(res.contextSize).to.toMatchInlineSnapshot("11260"); } }); @@ -496,7 +496,7 @@ describe("functionary", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.eql(16); - expect(res.contextSize).to.toMatchInlineSnapshot("16575"); + expect(res.contextSize).to.toMatchInlineSnapshot("15358"); } try { resolveGpuLayers({min: 16}, { @@ -514,7 +514,7 @@ describe("functionary", () => { }); expect(res.gpuLayers).to.be.gte(16); expect(res.gpuLayers).to.toMatchInlineSnapshot("22"); - expect(res.contextSize).to.toMatchInlineSnapshot("8968"); + expect(res.contextSize).to.toMatchInlineSnapshot("8386"); } { const res = resolveGpuLayers({min: 16, max: 24}, { @@ -524,7 +524,7 @@ describe("functionary", () => { expect(res.gpuLayers).to.be.gte(16); expect(res.gpuLayers).to.be.lte(24); expect(res.gpuLayers).to.toMatchInlineSnapshot("22"); - expect(res.contextSize).to.toMatchInlineSnapshot("8968"); + expect(res.contextSize).to.toMatchInlineSnapshot("8386"); } { const res = resolveGpuLayers({min: 16, max: 24}, { @@ -534,7 +534,7 @@ describe("functionary", () => { expect(res.gpuLayers).to.be.gte(16); expect(res.gpuLayers).to.be.lte(24); expect(res.gpuLayers).to.toMatchInlineSnapshot("16"); - expect(res.contextSize).to.toMatchInlineSnapshot("8037"); + expect(res.contextSize).to.toMatchInlineSnapshot("7415"); } }); @@ -557,7 +557,7 @@ describe("functionary", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("26"); - expect(res.contextSize).to.toMatchInlineSnapshot("5142"); + expect(res.contextSize).to.toMatchInlineSnapshot("4819"); expect(res.contextSize).to.be.gte(contextSize); } { @@ -566,8 +566,8 @@ describe("functionary", () => { totalVram: s1GB * 2, freeVram: s1GB * 1 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("4"); - expect(res.contextSize).to.toMatchInlineSnapshot("4385"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("3"); + expect(res.contextSize).to.toMatchInlineSnapshot("5495"); expect(res.contextSize).to.be.gte(contextSize); } { @@ -576,8 +576,8 @@ describe("functionary", () => { totalVram: s1GB * 6, freeVram: s1GB * 4 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("22"); - expect(res.contextSize).to.toMatchInlineSnapshot("8968"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("21"); + expect(res.contextSize).to.toMatchInlineSnapshot("9395"); expect(res.contextSize).to.be.gte(contextSize); } { @@ -586,8 +586,8 @@ describe("functionary", () => { totalVram: s1GB * 1, freeVram: s1GB * 1 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("2"); - expect(res.contextSize).to.toMatchInlineSnapshot("8497"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("1"); + expect(res.contextSize).to.toMatchInlineSnapshot("9434"); expect(res.contextSize).to.be.gte(contextSize); } { diff --git a/test/modelDependent/functionary/gguf/ggufInsights.test.ts b/test/modelDependent/functionary/gguf/ggufInsights.test.ts index d329ab42..303c9d3f 100644 --- a/test/modelDependent/functionary/gguf/ggufInsights.test.ts +++ b/test/modelDependent/functionary/gguf/ggufInsights.test.ts @@ -121,7 +121,7 @@ describe("gguf", async () => { sequences: context.totalSequences, modelGpuLayers: ggufInsights.totalLayers }).gpuVram; - expect(bytes(estimatedContextVramUsage)).toMatchInlineSnapshot('"809.83MB"'); + expect(bytes(estimatedContextVramUsage)).toMatchInlineSnapshot('"854.74MB"'); expect(Math.abs(contextVramUsageDiff - estimatedContextVramUsage)).to.be.lte(s100MB); await model.dispose(); @@ -165,7 +165,7 @@ describe("gguf", async () => { batchSize: 512 }))).toMatchInlineSnapshot(` { - "cpuRam": "1.52GB", + "cpuRam": "1.6GB", "gpuVram": "0B", } `); @@ -176,7 +176,7 @@ describe("gguf", async () => { batchSize: 512 }))).toMatchInlineSnapshot(` { - "cpuRam": "809.83MB", + "cpuRam": "854.63MB", "gpuVram": "0B", } `); @@ -187,7 +187,7 @@ describe("gguf", async () => { batchSize: 512 }))).toMatchInlineSnapshot(` { - "cpuRam": "436.2MB", + "cpuRam": "462.6MB", "gpuVram": "0B", } `); @@ -198,7 +198,7 @@ describe("gguf", async () => { batchSize: 512 }))).toMatchInlineSnapshot(` { - "cpuRam": "249.39MB", + "cpuRam": "266.59MB", "gpuVram": "0B", } `); @@ -211,7 +211,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "1GB", - "gpuVram": "565.1MB", + "gpuVram": "646.7MB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -222,7 +222,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "512MB", - "gpuVram": "313.83MB", + "gpuVram": "358.64MB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -233,7 +233,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "256MB", - "gpuVram": "188.2MB", + "gpuVram": "214.61MB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -244,7 +244,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "128MB", - "gpuVram": "125.39MB", + "gpuVram": "142.59MB", } `); @@ -256,7 +256,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "544MB", - "gpuVram": "1.02GB", + "gpuVram": "1.1GB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -267,7 +267,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "272MB", - "gpuVram": "553.83MB", + "gpuVram": "598.68MB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -278,7 +278,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "136MB", - "gpuVram": "308.2MB", + "gpuVram": "334.65MB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -289,7 +289,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "68MB", - "gpuVram": "185.39MB", + "gpuVram": "202.64MB", } `); @@ -301,7 +301,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "32MB", - "gpuVram": "1.52GB", + "gpuVram": "1.6GB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -312,7 +312,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "16MB", - "gpuVram": "809.83MB", + "gpuVram": "854.73MB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -323,7 +323,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "8MB", - "gpuVram": "436.2MB", + "gpuVram": "462.7MB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -334,7 +334,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "4MB", - "gpuVram": "249.39MB", + "gpuVram": "266.69MB", } `); @@ -346,7 +346,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "0B", - "gpuVram": "1.52GB", + "gpuVram": "1.6GB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -357,7 +357,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "0B", - "gpuVram": "809.83MB", + "gpuVram": "854.74MB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -368,7 +368,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "0B", - "gpuVram": "436.2MB", + "gpuVram": "462.7MB", } `); expect(makeEstimationReadable(ggufInsights.estimateContextResourceRequirements({ @@ -379,7 +379,7 @@ describe("gguf", async () => { }))).toMatchInlineSnapshot(` { "cpuRam": "0B", - "gpuVram": "249.39MB", + "gpuVram": "266.69MB", } `); }); diff --git a/test/modelDependent/stableCode/parallel.test.ts b/test/modelDependent/stableCode/parallel.test.ts index f600c8fd..8f255b1e 100644 --- a/test/modelDependent/stableCode/parallel.test.ts +++ b/test/modelDependent/stableCode/parallel.test.ts @@ -95,8 +95,8 @@ describe("stableCode", () => { const expectedFullCompletion = " " + range(4, 100).join(", "); const expectedFullCompletion2 = " " + range(96, 1).join(", "); - expect(expectedFullCompletion.slice(0, res.length)).to.eql(res); - expect(expectedFullCompletion2.slice(0, res2.length)).to.eql(res2); + expect(res).to.eql(expectedFullCompletion.slice(0, res.length)); + expect(res2).to.eql(expectedFullCompletion2.slice(0, res2.length)); }); test("can use multiple contexts in parallel", {timeout: 1000 * 60 * 60 * 2}, async () => { @@ -136,11 +136,11 @@ describe("stableCode", () => { const expectedFullCompletion = " " + range(4, 100).join(", "); const expectedFullCompletion2 = " " + range(96, 1).join(", "); - expect(expectedFullCompletion.slice(0, res.length)).to.eql(res); - expect(expectedFullCompletion2.slice(0, res2.length)).to.eql(res2); + expect(res).to.eql(expectedFullCompletion.slice(0, res.length)); + expect(res2).to.eql(expectedFullCompletion2.slice(0, res2.length)); }); - test("can use multiple context sequences in parallel", {timeout: 1000 * 60 * 60 * 2}, async () => { + test("can use multiple context sequences in parallel", {timeout: 1000 * 60 * 60 * 2, retry: 4}, async () => { const modelPath = await getModelFile("stable-code-3b.Q5_K_M.gguf"); const llama = await getTestLlama(); @@ -149,7 +149,8 @@ describe("stableCode", () => { }); const context = await model.createContext({ contextSize: 4096, - sequences: 2 + sequences: 2, + seed: 0 }); const completion = new LlamaCompletion({ contextSequence: context.getSequence() @@ -158,11 +159,11 @@ describe("stableCode", () => { contextSequence: context.getSequence() }); - const resPromise = completion.generateCompletion("const arrayFromOneToHundred = [1, 2, 3,", { - maxTokens: 50 + const resPromise = completion.generateCompletion("const arrayFromOneToHundred = [1, 2, 3", { + maxTokens: 40 }); - const resPromise2 = completion2.generateCompletion("const arrayFromOneHundredToOne = [100, 99, 98, 97,", { - maxTokens: 50 + const resPromise2 = completion2.generateCompletion("const arrayFromOneHundredToOne = [100, 99, 98, 97, 96", { + maxTokens: 40 }); const [ @@ -173,10 +174,10 @@ describe("stableCode", () => { resPromise2 ]); - const expectedFullCompletion = " " + range(4, 100).join(", "); - const expectedFullCompletion2 = " " + range(96, 1).join(", "); - expect(expectedFullCompletion.slice(0, res.length)).to.eql(res); - expect(expectedFullCompletion2.slice(0, res2.length)).to.eql(res2); + const expectedFullCompletion = ", " + range(4, 100).join(", "); + const expectedFullCompletion2 = ", " + range(95, 1).join(", "); + expect(res).to.eql(expectedFullCompletion.slice(0, res.length)); + expect(res2).to.eql(expectedFullCompletion2.slice(0, res2.length)); }); }); }); diff --git a/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts b/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts index 6c6b091f..94bbc011 100644 --- a/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts +++ b/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts @@ -303,7 +303,7 @@ describe("stableCode", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("5853"); + expect(res.contextSize).to.toMatchInlineSnapshot("5852"); } { const res = resolveGpuLayers("max", { @@ -319,7 +319,7 @@ describe("stableCode", () => { freeVram: s1GB * 4.8 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("8138"); + expect(res.contextSize).to.toMatchInlineSnapshot("8137"); } }); @@ -402,7 +402,7 @@ describe("stableCode", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("5853"); + expect(res.contextSize).to.toMatchInlineSnapshot("5852"); } { const res = resolveGpuLayers("auto", { @@ -410,7 +410,7 @@ describe("stableCode", () => { freeVram: s1GB * 4.3 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("6710"); + expect(res.contextSize).to.toMatchInlineSnapshot("6709"); } { const res = resolveGpuLayers("auto", { @@ -426,7 +426,7 @@ describe("stableCode", () => { freeVram: s1GB * 4.8 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("8138"); + expect(res.contextSize).to.toMatchInlineSnapshot("8137"); } { const res = resolveGpuLayers("auto", { @@ -514,7 +514,7 @@ describe("stableCode", () => { }); expect(res.gpuLayers).to.be.gte(16); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("5853"); + expect(res.contextSize).to.toMatchInlineSnapshot("5852"); } { const res = resolveGpuLayers({min: 16, max: 24}, { @@ -534,7 +534,7 @@ describe("stableCode", () => { expect(res.gpuLayers).to.be.gte(16); expect(res.gpuLayers).to.be.lte(24); expect(res.gpuLayers).to.toMatchInlineSnapshot("18"); - expect(res.contextSize).to.toMatchInlineSnapshot("8208"); + expect(res.contextSize).to.toMatchInlineSnapshot("8207"); } }); @@ -557,7 +557,7 @@ describe("stableCode", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("5853"); + expect(res.contextSize).to.toMatchInlineSnapshot("5852"); expect(res.contextSize).to.be.gte(contextSize); } { From 67f89c6e5242ac1b664dbea1bbfcd059ed4fc6d4 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Thu, 4 Apr 2024 21:32:08 +0300 Subject: [PATCH 51/52] feat: improve VRAM consumption estimations --- src/gguf/GgufInsights.ts | 22 +++++++ .../stableCodeModelGpuLayersOptions.test.ts | 64 +++++++++---------- 2 files changed, 54 insertions(+), 32 deletions(-) diff --git a/src/gguf/GgufInsights.ts b/src/gguf/GgufInsights.ts index 5dfc4b81..72ff9655 100644 --- a/src/gguf/GgufInsights.ts +++ b/src/gguf/GgufInsights.ts @@ -185,6 +185,28 @@ export class GgufInsights { ) ); } + } else if (this.ggufFileInfo.metadata.general.architecture === GgufArchitectureType.stablelm) { + const headCount = this.ggufFileInfo.architectureMetadata.attention?.head_count ?? 0; + + return (int32TBytes * batchSize * actualContextSize * headCount) - (50 * s1MB); + + // if (modelGpuLayers === this.totalLayers) { + // defaultCalculationAdjustment += -(s1MB * 20) + ( + // (s1MB * 250) * ( + // this.trainContextSize == null + // ? 1 + // : actualContextSize / this.trainContextSize + // ) + // ); + // } else { + // defaultCalculationAdjustment += -(s1MB * 40) + ( + // (s1MB * 300) * ( + // this.trainContextSize == null + // ? 1 + // : actualContextSize / this.trainContextSize + // ) + // ); + // } } const totalElements = tensorInfo.length === 0 diff --git a/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts b/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts index 94bbc011..09da33c1 100644 --- a/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts +++ b/test/modelDependent/stableCode/stableCodeModelGpuLayersOptions.test.ts @@ -104,7 +104,7 @@ describe("stableCode", () => { freeVram: s1GB * 3 }); expect(res.gpuLayers).to.eql(16); - expect(res.contextSize).to.toMatchInlineSnapshot("9530"); + expect(res.contextSize).to.toMatchInlineSnapshot("8652"); } try { resolveGpuLayers(16, { @@ -167,7 +167,7 @@ describe("stableCode", () => { freeVram: s1GB * 6 }); expect(res.gpuLayers).to.eql(32); - expect(res.contextSize).to.toMatchInlineSnapshot("11565"); + expect(res.contextSize).to.toMatchInlineSnapshot("10905"); } try { resolveGpuLayers(32, { @@ -216,7 +216,7 @@ describe("stableCode", () => { freeVram: s1GB * 6 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("11565"); + expect(res.contextSize).to.toMatchInlineSnapshot("10905"); } try { resolveGpuLayers(33, { @@ -303,7 +303,7 @@ describe("stableCode", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("5852"); + expect(res.contextSize).to.toMatchInlineSnapshot("5583"); } { const res = resolveGpuLayers("max", { @@ -311,7 +311,7 @@ describe("stableCode", () => { freeVram: s1GB * 4.4 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("6995"); + expect(res.contextSize).to.toMatchInlineSnapshot("6647"); } { const res = resolveGpuLayers("max", { @@ -319,7 +319,7 @@ describe("stableCode", () => { freeVram: s1GB * 4.8 }); expect(res.gpuLayers).to.eql(33); - expect(res.contextSize).to.toMatchInlineSnapshot("8137"); + expect(res.contextSize).to.toMatchInlineSnapshot("7712"); } }); @@ -346,15 +346,15 @@ describe("stableCode", () => { freeVram: s1GB * 0.8 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("4"); - expect(res.contextSize).to.toMatchInlineSnapshot("3742"); + expect(res.contextSize).to.toMatchInlineSnapshot("3308"); } { const res = resolveGpuLayers("auto", { totalVram: s1GB * 6, freeVram: s1GB * 1.4 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("10"); - expect(res.contextSize).to.toMatchInlineSnapshot("4249"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("9"); + expect(res.contextSize).to.toMatchInlineSnapshot("4467"); } { const res = resolveGpuLayers("auto", { @@ -362,7 +362,7 @@ describe("stableCode", () => { freeVram: s1GB * 2.4 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("1282"); + expect(res.contextSize).to.toMatchInlineSnapshot("1325"); } { const res = resolveGpuLayers("auto", { @@ -370,7 +370,7 @@ describe("stableCode", () => { freeVram: s1GB * 3.1 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("3282"); + expect(res.contextSize).to.toMatchInlineSnapshot("3187"); } { const res = resolveGpuLayers("auto", { @@ -378,7 +378,7 @@ describe("stableCode", () => { freeVram: s1GB * 3.3 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("3853"); + expect(res.contextSize).to.toMatchInlineSnapshot("3720"); } { const res = resolveGpuLayers("auto", { @@ -386,7 +386,7 @@ describe("stableCode", () => { freeVram: s1GB * 3.5 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("4424"); + expect(res.contextSize).to.toMatchInlineSnapshot("4252"); } { const res = resolveGpuLayers("auto", { @@ -394,7 +394,7 @@ describe("stableCode", () => { freeVram: s1GB * 3.8 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("5281"); + expect(res.contextSize).to.toMatchInlineSnapshot("5050"); } { const res = resolveGpuLayers("auto", { @@ -402,7 +402,7 @@ describe("stableCode", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("5852"); + expect(res.contextSize).to.toMatchInlineSnapshot("5583"); } { const res = resolveGpuLayers("auto", { @@ -410,7 +410,7 @@ describe("stableCode", () => { freeVram: s1GB * 4.3 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("6709"); + expect(res.contextSize).to.toMatchInlineSnapshot("6381"); } { const res = resolveGpuLayers("auto", { @@ -418,7 +418,7 @@ describe("stableCode", () => { freeVram: s1GB * 4.5 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("7281"); + expect(res.contextSize).to.toMatchInlineSnapshot("6913"); } { const res = resolveGpuLayers("auto", { @@ -426,7 +426,7 @@ describe("stableCode", () => { freeVram: s1GB * 4.8 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("8137"); + expect(res.contextSize).to.toMatchInlineSnapshot("7712"); } { const res = resolveGpuLayers("auto", { @@ -434,7 +434,7 @@ describe("stableCode", () => { freeVram: s1GB * 5.2 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("9280"); + expect(res.contextSize).to.toMatchInlineSnapshot("8776"); } { const res = resolveGpuLayers("auto", { @@ -442,7 +442,7 @@ describe("stableCode", () => { freeVram: s1GB * 5.8 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("10994"); + expect(res.contextSize).to.toMatchInlineSnapshot("10373"); } { const res = resolveGpuLayers("auto", { @@ -450,7 +450,7 @@ describe("stableCode", () => { freeVram: s1GB * 6 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("11565"); + expect(res.contextSize).to.toMatchInlineSnapshot("10905"); } }); @@ -496,7 +496,7 @@ describe("stableCode", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.eql(16); - expect(res.contextSize).to.toMatchInlineSnapshot("14593"); + expect(res.contextSize).to.toMatchInlineSnapshot("13133"); } try { resolveGpuLayers({min: 16}, { @@ -514,7 +514,7 @@ describe("stableCode", () => { }); expect(res.gpuLayers).to.be.gte(16); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("5852"); + expect(res.contextSize).to.toMatchInlineSnapshot("5583"); } { const res = resolveGpuLayers({min: 16, max: 24}, { @@ -524,7 +524,7 @@ describe("stableCode", () => { expect(res.gpuLayers).to.be.gte(16); expect(res.gpuLayers).to.be.lte(24); expect(res.gpuLayers).to.toMatchInlineSnapshot("24"); - expect(res.contextSize).to.toMatchInlineSnapshot("9020"); + expect(res.contextSize).to.toMatchInlineSnapshot("8410"); } { const res = resolveGpuLayers({min: 16, max: 24}, { @@ -533,8 +533,8 @@ describe("stableCode", () => { }); expect(res.gpuLayers).to.be.gte(16); expect(res.gpuLayers).to.be.lte(24); - expect(res.gpuLayers).to.toMatchInlineSnapshot("18"); - expect(res.contextSize).to.toMatchInlineSnapshot("8207"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("17"); + expect(res.contextSize).to.toMatchInlineSnapshot("8079"); } }); @@ -557,7 +557,7 @@ describe("stableCode", () => { freeVram: s1GB * 4 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("33"); - expect(res.contextSize).to.toMatchInlineSnapshot("5852"); + expect(res.contextSize).to.toMatchInlineSnapshot("5583"); expect(res.contextSize).to.be.gte(contextSize); } { @@ -567,7 +567,7 @@ describe("stableCode", () => { freeVram: s1GB * 1 }); expect(res.gpuLayers).to.toMatchInlineSnapshot("5"); - expect(res.contextSize).to.toMatchInlineSnapshot("4959"); + expect(res.contextSize).to.toMatchInlineSnapshot("4295"); expect(res.contextSize).to.be.gte(contextSize); } { @@ -576,8 +576,8 @@ describe("stableCode", () => { totalVram: s1GB * 6, freeVram: s1GB * 4 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("25"); - expect(res.contextSize).to.toMatchInlineSnapshot("8537"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("24"); + expect(res.contextSize).to.toMatchInlineSnapshot("8410"); expect(res.contextSize).to.be.gte(contextSize); } { @@ -586,8 +586,8 @@ describe("stableCode", () => { totalVram: s1GB * 1, freeVram: s1GB * 1 }); - expect(res.gpuLayers).to.toMatchInlineSnapshot("2"); - expect(res.contextSize).to.toMatchInlineSnapshot("9618"); + expect(res.gpuLayers).to.toMatchInlineSnapshot("1"); + expect(res.contextSize).to.toMatchInlineSnapshot("8962"); expect(res.contextSize).to.be.gte(contextSize); } { From cb2f3c875e8cfa32cdd0a1802b375eed8e627731 Mon Sep 17 00:00:00 2001 From: Gilad S Date: Thu, 4 Apr 2024 22:21:49 +0300 Subject: [PATCH 52/52] feat: reserve memory for a model/context before its creation and release it afterward --- src/evaluator/LlamaContext/LlamaContext.ts | 30 ++++++++++++++++------ src/evaluator/LlamaModel.ts | 6 +++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/src/evaluator/LlamaContext/LlamaContext.ts b/src/evaluator/LlamaContext/LlamaContext.ts index 440e1c9c..53938d7a 100644 --- a/src/evaluator/LlamaContext/LlamaContext.ts +++ b/src/evaluator/LlamaContext/LlamaContext.ts @@ -551,21 +551,35 @@ export class LlamaContext { isEmbeddingContext: options._embeddings }); const batchSize = options.batchSize ?? getDefaultContextBatchSize({contextSize, sequences}); + const vramRequiredEstimate = _model.fileInsights.estimateContextResourceRequirements({ + contextSize, + sequences, + isEmbeddingContext: options._embeddings, + modelGpuLayers: _model.gpuLayers, + batchSize + }).gpuVram; const context = new LlamaContext({_model}, {...options, contextSize, batchSize, sequences}); const {createSignal} = options; + const contextCreationMemoryReservation = options.ignoreMemorySafetyChecks + ? null + : _model._llama._vramOrchestrator.reserveMemory(vramRequiredEstimate); - const contextLoaded = await context._ctx.init(); + try { + const contextLoaded = await context._ctx.init(); - if (createSignal?.aborted) { - if (contextLoaded) - await context._ctx.dispose(); + if (createSignal?.aborted) { + if (contextLoaded) + await context._ctx.dispose(); - throw createSignal.reason; - } else if (!contextLoaded) - throw new Error("Failed to create context"); + throw createSignal.reason; + } else if (!contextLoaded) + throw new Error("Failed to create context"); - return context; + return context; + } finally { + contextCreationMemoryReservation?.dispose?.(); + } } } diff --git a/src/evaluator/LlamaModel.ts b/src/evaluator/LlamaModel.ts index 0437a50f..ec569005 100644 --- a/src/evaluator/LlamaModel.ts +++ b/src/evaluator/LlamaModel.ts @@ -387,7 +387,12 @@ export class LlamaModel { llamaGpu: _llama.gpu, llamaSupportsGpuOffloading: _llama.supportsGpuOffloading }); + const vramRequiredEstimate = ggufInsights.estimateModelResourceRequirements({gpuLayers: gpuLayers}).gpuVram; + const model = new LlamaModel({...modelOptions, gpuLayers}, {_fileInfo: fileInfo, _fileInsights: ggufInsights, _llama}); + const modelCreationMemoryReservation = modelOptions.ignoreMemorySafetyChecks + ? null + : _llama._vramOrchestrator.reserveMemory(vramRequiredEstimate); function onAbort() { model._model.abortActiveModelLoad(); @@ -415,6 +420,7 @@ export class LlamaModel { return model; } finally { loadSignal?.removeEventListener("abort", onAbort); + modelCreationMemoryReservation?.dispose?.(); } } }