diff --git a/packages/inference/README.md b/packages/inference/README.md
index 55cff9429..2da657a93 100644
--- a/packages/inference/README.md
+++ b/packages/inference/README.md
@@ -61,6 +61,7 @@ Currently, we support the following providers:
 - [Blackforestlabs](https://blackforestlabs.ai)
 - [Cohere](https://cohere.com)
 - [Cerebras](https://cerebras.ai/)
+- [CentML](https://centml.ai)
 - [Groq](https://groq.com)
 
 To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
@@ -91,6 +92,7 @@ Only a subset of models are supported when requesting third-party providers. You
 - [Together supported models](https://huggingface.co/api/partners/together/models)
 - [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
 - [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
+- [CentML supported models](https://huggingface.co/api/partners/centml/models)
 - [Groq supported models](https://console.groq.com/docs/models)
 - [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
 
diff --git a/packages/inference/src/lib/getProviderHelper.ts b/packages/inference/src/lib/getProviderHelper.ts
index 2a1ce00fe..d86603a69 100644
--- a/packages/inference/src/lib/getProviderHelper.ts
+++ b/packages/inference/src/lib/getProviderHelper.ts
@@ -1,5 +1,6 @@
 import * as BlackForestLabs from "../providers/black-forest-labs";
 import * as Cerebras from "../providers/cerebras";
+import * as CentML from "../providers/centml";
 import * as Cohere from "../providers/cohere";
 import * as FalAI from "../providers/fal-ai";
 import * as FeatherlessAI from "../providers/featherless-ai";
@@ -56,6 +57,9 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask, TaskProviderHelper>>> = {
 	cerebras: {
 		conversational: new Cerebras.CerebrasConversationalTask(),
 	},
+	centml: {
+		conversational: new CentML.CentMLConversationalTask(),
+	},
 	cohere: {
 		conversational: new Cohere.CohereConversationalTask(),
 	},
diff --git a/packages/inference/src/types.ts b/packages/inference/src/types.ts
--- a/packages/inference/src/types.ts
+++ b/packages/inference/src/types.ts
@@ ... @@ export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
 export const INFERENCE_PROVIDERS = [
 	"black-forest-labs",
 	"cerebras",
+	"centml",
 	"cohere",
 	"fal-ai",
 	"featherless-ai",
diff --git a/packages/inference/test/InferenceClient.spec.ts b/packages/inference/test/InferenceClient.spec.ts
index f8ab02d1d..7b4a1035c 100644
--- a/packages/inference/test/InferenceClient.spec.ts
+++ b/packages/inference/test/InferenceClient.spec.ts
@@ -1978,4 +1978,89 @@ describe.skip("InferenceClient", () => {
 		},
 		TIMEOUT
 	);
+	describe.concurrent(
+		"CentML",
+		() => {
+			const client = new InferenceClient(env.HF_CENTML_KEY ?? "dummy");
+
+			HARDCODED_MODEL_INFERENCE_MAPPING["centml"] = {
+				"meta-llama/Llama-3.2-3B-Instruct": {
+					hfModelId: "meta-llama/Llama-3.2-3B-Instruct",
+					providerId: "meta-llama/Llama-3.2-3B-Instruct",
+					status: "live",
+					task: "conversational",
+				},
+			};
+
+			describe("chat completions", () => {
+				it("basic chat completion", async () => {
+					const res = await client.chatCompletion({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
+					});
+					if (res.choices && res.choices.length > 0) {
+						const completion = res.choices[0].message?.content;
+						expect(completion).toContain("two");
+					}
+				});
+
+				it("chat completion with multiple messages", async () => {
+					const res = await client.chatCompletion({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [
+							{ role: "system", content: "You are a helpful assistant." },
+							{ role: "user", content: "What is 2+2?" },
+							{ role: "assistant", content: "The answer is 4." },
+							{ role: "user", content: "What is 3+3?" },
+						],
+					});
+					if (res.choices && res.choices.length > 0) {
+						const completion = res.choices[0].message?.content;
+						expect(completion).toContain("6");
+					}
+				});
+
+				it("chat completion with parameters", async () => {
+					const res = await client.chatCompletion({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [{ role: "user", content: "Write a short poem about AI" }],
+						temperature: 0.7,
+						max_tokens: 100,
+						top_p: 0.9,
+					});
+					if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
+						const completion = res.choices[0].message.content;
+						expect(completion).toBeTruthy();
+						expect(completion.length).toBeGreaterThan(0);
+					}
+				});
+
+				it("chat completion stream", async () => {
+					const stream = client.chatCompletionStream({
+						model: "meta-llama/Llama-3.2-3B-Instruct",
+						provider: "centml",
+						messages: [{ role: "user", content: "Say 'this is a test'" }],
+						stream: true,
+					}) as AsyncGenerator<ChatCompletionStreamOutput>;
+
+					let fullResponse = "";
+					for await (const chunk of stream) {
+						if (chunk.choices && chunk.choices.length > 0) {
+							const content = chunk.choices[0].delta?.content;
+							if (content) {
+								fullResponse += content;
+							}
+						}
+					}
+
+					expect(fullResponse).toBeTruthy();
+					expect(fullResponse.length).toBeGreaterThan(0);
+				});
+			});
+		},
+		TIMEOUT
+	);
 });
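For context, here is a minimal usage sketch of the provider this diff adds, following the README guidance to pass the `provider` parameter and authenticate with an access token. It is illustrative only and not part of the change: the environment variable name `HF_CENTML_KEY` is borrowed from the test above, and the model ID is assumed to be mapped to CentML for the conversational task.

```ts
// Illustrative sketch (not part of the diff): calling the new "centml" provider.
// Assumes HF_CENTML_KEY holds a valid access token and that the model below is
// among the CentML-supported models linked in the README.
import { InferenceClient } from "@huggingface/inference";

const client = new InferenceClient(process.env.HF_CENTML_KEY);

const res = await client.chatCompletion({
	model: "meta-llama/Llama-3.2-3B-Instruct",
	provider: "centml",
	messages: [{ role: "user", content: "One plus one is equal to" }],
});

console.log(res.choices[0]?.message?.content);
```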