Add CentML as an inference provider #1394

Open · wants to merge 16 commits into main
2 changes: 2 additions & 0 deletions packages/inference/README.md
@@ -61,6 +61,7 @@ Currently, we support the following providers:
- [Blackforestlabs](https://blackforestlabs.ai)
- [Cohere](https://cohere.com)
- [Cerebras](https://cerebras.ai/)
- [CentML](https://centml.ai)
- [Groq](https://groq.com)

To send requests to a third-party provider, you have to pass the `provider` parameter to the inference function. Make sure your request is authenticated with an access token.
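A minimal sketch of such a request through `InferenceClient`, with CentML as the provider (the model ID and token variable below are placeholders):

```ts
import { InferenceClient } from "@huggingface/inference";

// Placeholder access token and model ID; any model mapped to CentML works the same way.
const client = new InferenceClient(process.env.HF_TOKEN);
const out = await client.chatCompletion({
	model: "meta-llama/Llama-3.2-3B-Instruct",
	provider: "centml",
	messages: [{ role: "user", content: "Hello!" }],
});
console.log(out.choices[0]?.message?.content);
```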
@@ -91,6 +92,7 @@ Only a subset of models are supported when requesting third-party providers. You
- [Together supported models](https://huggingface.co/api/partners/together/models)
- [Cohere supported models](https://huggingface.co/api/partners/cohere/models)
- [Cerebras supported models](https://huggingface.co/api/partners/cerebras/models)
- [CentML supported models](https://huggingface.co/api/partners/centml/models)
- [Groq supported models](https://console.groq.com/docs/models)
- [HF Inference API (serverless)](https://huggingface.co/models?inference=warm&sort=trending)
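As an aside, these mapping pages are plain URLs, so the CentML list above can also be pulled programmatically; a small sketch (the response is just printed as-is, since its exact shape isn't documented in this diff):

```ts
// Fetch the CentML model mapping from the partner endpoint listed above.
const res = await fetch("https://huggingface.co/api/partners/centml/models");
console.log(await res.json());
```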

4 changes: 4 additions & 0 deletions packages/inference/src/lib/getProviderHelper.ts
@@ -1,5 +1,6 @@
import * as BlackForestLabs from "../providers/black-forest-labs";
import * as Cerebras from "../providers/cerebras";
import * as CentML from "../providers/centml";
import * as Cohere from "../providers/cohere";
import * as FalAI from "../providers/fal-ai";
import * as FeatherlessAI from "../providers/featherless-ai";
@@ -56,6 +57,9 @@ export const PROVIDERS: Record<InferenceProvider, Partial<Record<InferenceTask,
	cerebras: {
		conversational: new Cerebras.CerebrasConversationalTask(),
	},
	centml: {
		conversational: new CentML.CentMLConversationalTask(),
	},
	cohere: {
		conversational: new Cohere.CohereConversationalTask(),
	},
17 changes: 17 additions & 0 deletions packages/inference/src/providers/centml.ts
@@ -0,0 +1,17 @@
/**
 * CentML provider implementation for serverless inference.
 * This provider supports chat completions (the conversational task) through CentML's OpenAI-compatible serverless endpoints.
 */
import { BaseConversationalTask } from "./providerHelper";

const CENTML_API_BASE_URL = "https://api.centml.com";

export class CentMLConversationalTask extends BaseConversationalTask {
	constructor() {
		super("centml", CENTML_API_BASE_URL);
	}

	override makeRoute(): string {
		return "openai/v1/chat/completions";
	}
}
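For reference, a rough sketch of the endpoint this route resolves to, assuming the base helper simply joins the provider base URL and `makeRoute()` with a `/` (that joining behavior is an assumption, not something shown in this diff):

```ts
// Assumption: BaseConversationalTask builds the request URL as `${baseUrl}/${route}`.
const chatCompletionsUrl = `${CENTML_API_BASE_URL}/openai/v1/chat/completions`;
// i.e. "https://api.centml.com/openai/v1/chat/completions"
```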
1 change: 1 addition & 0 deletions packages/inference/src/providers/consts.ts
@@ -21,6 +21,7 @@ export const HARDCODED_MODEL_INFERENCE_MAPPING: Record<
	 */
	"black-forest-labs": {},
	cerebras: {},
	centml: {},
	cohere: {},
	"fal-ai": {},
	"featherless-ai": {},
1 change: 1 addition & 0 deletions packages/inference/src/types.ts
@@ -40,6 +40,7 @@ export type InferenceTask = Exclude<PipelineType, "other"> | "conversational";
export const INFERENCE_PROVIDERS = [
"black-forest-labs",
"cerebras",
"centml",
"cohere",
"fal-ai",
"featherless-ai",
85 changes: 85 additions & 0 deletions packages/inference/test/InferenceClient.spec.ts
@@ -1978,4 +1978,89 @@ describe.skip("InferenceClient", () => {
		},
		TIMEOUT
	);
	describe.concurrent(
		"CentML",
		() => {
			const client = new InferenceClient(env.HF_CENTML_KEY ?? "dummy");

			HARDCODED_MODEL_INFERENCE_MAPPING["centml"] = {
				"meta-llama/Llama-3.2-3B-Instruct": {
					hfModelId: "meta-llama/Llama-3.2-3B-Instruct",
					providerId: "meta-llama/Llama-3.2-3B-Instruct",
					status: "live",
					task: "conversational",
				},
			};

			describe("chat completions", () => {
				it("basic chat completion", async () => {
					const res = await client.chatCompletion({
						model: "meta-llama/Llama-3.2-3B-Instruct",
						provider: "centml",
						messages: [{ role: "user", content: "Complete this sentence with words, one plus one is equal " }],
					});
					if (res.choices && res.choices.length > 0) {
						const completion = res.choices[0].message?.content;
						expect(completion).toContain("two");
					}
				});

				it("chat completion with multiple messages", async () => {
					const res = await client.chatCompletion({
						model: "meta-llama/Llama-3.2-3B-Instruct",
						provider: "centml",
						messages: [
							{ role: "system", content: "You are a helpful assistant." },
							{ role: "user", content: "What is 2+2?" },
							{ role: "assistant", content: "The answer is 4." },
							{ role: "user", content: "What is 3+3?" },
						],
					});
					if (res.choices && res.choices.length > 0) {
						const completion = res.choices[0].message?.content;
						expect(completion).toContain("6");
					}
				});

				it("chat completion with parameters", async () => {
					const res = await client.chatCompletion({
						model: "meta-llama/Llama-3.2-3B-Instruct",
						provider: "centml",
						messages: [{ role: "user", content: "Write a short poem about AI" }],
						temperature: 0.7,
						max_tokens: 100,
						top_p: 0.9,
					});
					if (res.choices && res.choices.length > 0 && res.choices[0].message?.content) {
						const completion = res.choices[0].message.content;
						expect(completion).toBeTruthy();
						expect(completion.length).toBeGreaterThan(0);
					}
				});

				it("chat completion stream", async () => {
					const stream = client.chatCompletionStream({
						model: "meta-llama/Llama-3.2-3B-Instruct",
						provider: "centml",
						messages: [{ role: "user", content: "Say 'this is a test'" }],
						stream: true,
					}) as AsyncGenerator<ChatCompletionStreamOutput>;

					let fullResponse = "";
					for await (const chunk of stream) {
						if (chunk.choices && chunk.choices.length > 0) {
							const content = chunk.choices[0].delta?.content;
							if (content) {
								fullResponse += content;
							}
						}
					}

					expect(fullResponse).toBeTruthy();
					expect(fullResponse.length).toBeGreaterThan(0);
				});
			});
		},
		TIMEOUT
	);
});