Add support for Entra ID authentication when using interpreter or openai-gpt #356

Merged (4 commits) on Mar 11, 2025
Changes from 2 commits
@@ -17,7 +17,8 @@

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" Version="1.0.0-beta.13" />
<PackageReference Include="Azure.Core" Version="1.37.0" />
<PackageReference Include="Azure.Identity" Version="1.13.2" />
<PackageReference Include="Azure.Core" Version="1.44.1" />
<PackageReference Include="SharpToken" Version="2.0.3" />
</ItemGroup>

25 changes: 20 additions & 5 deletions shell/agents/AIShell.Interpreter.Agent/Agent.cs
@@ -236,7 +236,7 @@ private void OnSettingFileChange(object sender, FileSystemEventArgs e)

private void NewExampleSettingFile()
{
string SampleContent = """
string sample = $$"""
{
// To use the Azure OpenAI service:
// - Set `Endpoint` to the endpoint of your Azure OpenAI service,
@@ -249,22 +249,37 @@ private void NewExampleSettingFile()
"Deployment": "",
"ModelName": "",
"Key": "",
"AuthType": "ApiKey",
"AutoExecution": false, // 'true' to allow the agent run code automatically; 'false' to always prompt before running code.
"DisplayErrors": true // 'true' to display the errors when running code; 'false' to hide the errors to be less verbose.

// To use Azure OpenAI service with Entra ID authentication:
// - Set `Endpoint` to the endpoint of your Azure OpenAI service.
// - Set `Deployment` to the deployment name of your Azure OpenAI service.
// - Set `ModelName` to the name of the model used for your deployment.
// - Set `AuthType` to "EntraID" to use Azure AD credentials.
/*
"Endpoint": "<insert your Azure OpenAI endpoint>",
"Deployment": "<insert your deployment name>",
"ModelName": "<insert the model name>",
"AuthType": "EntraID",
"AutoExecution": false,
"DisplayErrors": true
*/

// To use the public OpenAI service:
// - Ignore the `Endpoint` and `Deployment` keys.
// - Set `ModelName` to the name of the model to be used. e.g. "gpt-4o".
// - Set `Key` to be the OpenAI access token.
// Replace the above with the following:
/*
"ModelName": "",
"Key": "",
"ModelName": "<insert the model name>",
"Key": "<insert your key>",
"AuthType": "ApiKey",
"AutoExecution": false,
"DisplayErrors": true
*/
}
""";
File.WriteAllText(SettingFile, SampleContent, Encoding.UTF8);
File.WriteAllText(SettingFile, sample);
}
}
74 changes: 44 additions & 30 deletions shell/agents/AIShell.Interpreter.Agent/Service.cs
@@ -3,6 +3,7 @@
using Azure;
using Azure.Core;
using Azure.AI.OpenAI;
using Azure.Identity;
using SharpToken;

namespace AIShell.Interpreter.Agent;
@@ -121,25 +122,38 @@ private void ConnectToOpenAIClient()
{
// Create a client that targets Azure OpenAI service or Azure API Management service.
bool isApimEndpoint = _settings.Endpoint.EndsWith(Utils.ApimGatewayDomain);
if (isApimEndpoint)

if (_settings.AuthType == AuthType.EntraID)
{
string userkey = Utils.ConvertFromSecureString(_settings.Key);
clientOptions.AddPolicy(
new UserKeyPolicy(
new AzureKeyCredential(userkey),
Utils.ApimAuthorizationHeader),
HttpPipelinePosition.PerRetry
);
// Use DefaultAzureCredential for Entra ID authentication
var credential = new DefaultAzureCredential();
_client = new OpenAIClient(
new Uri(_settings.Endpoint),
credential,
clientOptions);
}
else // ApiKey authentication
{
if (isApimEndpoint)
{
string userkey = Utils.ConvertFromSecureString(_settings.Key);
clientOptions.AddPolicy(
new UserKeyPolicy(
new AzureKeyCredential(userkey),
Utils.ApimAuthorizationHeader),
HttpPipelinePosition.PerRetry
);
}

string azOpenAIApiKey = isApimEndpoint
? "placeholder-api-key"
: Utils.ConvertFromSecureString(_settings.Key);

_client = new OpenAIClient(
new Uri(_settings.Endpoint),
new AzureKeyCredential(azOpenAIApiKey),
clientOptions);
}

string azOpenAIApiKey = isApimEndpoint
? "placeholder-api-key"
: Utils.ConvertFromSecureString(_settings.Key);

_client = new OpenAIClient(
new Uri(_settings.Endpoint),
new AzureKeyCredential(azOpenAIApiKey),
clientOptions);
}
else
{
@@ -157,41 +171,41 @@ private int CountTokenForMessages(IEnumerable<ChatRequestMessage> messages)

int tokenNumber = 0;
foreach (ChatRequestMessage message in messages)
{
{
tokenNumber += tokensPerMessage;
tokenNumber += encoding.Encode(message.Role.ToString()).Count;

switch (message)
{
case ChatRequestSystemMessage systemMessage:
tokenNumber += encoding.Encode(systemMessage.Content).Count;
if(systemMessage.Name is not null)
if (systemMessage.Name is not null)
{
tokenNumber += tokensPerName;
tokenNumber += encoding.Encode(systemMessage.Name).Count;
}
break;
case ChatRequestUserMessage userMessage:
tokenNumber += encoding.Encode(userMessage.Content).Count;
if(userMessage.Name is not null)
if (userMessage.Name is not null)
{
tokenNumber += tokensPerName;
tokenNumber += encoding.Encode(userMessage.Name).Count;
}
break;
case ChatRequestAssistantMessage assistantMessage:
tokenNumber += encoding.Encode(assistantMessage.Content).Count;
if(assistantMessage.Name is not null)
if (assistantMessage.Name is not null)
{
tokenNumber += tokensPerName;
tokenNumber += encoding.Encode(assistantMessage.Name).Count;
}
if (assistantMessage.ToolCalls is not null)
{
// Count tokens for the tool call's properties
foreach(ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls)
foreach (ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls)
{
if(chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall)
if (chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall)
{
tokenNumber += encoding.Encode(functionToolCall.Id).Count;
tokenNumber += encoding.Encode(functionToolCall.Name).Count;
@@ -230,7 +244,7 @@ internal string ReduceToolResponseContentTokens(string content)
}
while (encoding.Encode(reducedContent).Count > MaxResponseToken);
}

return reducedContent;
}

@@ -287,7 +301,7 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
// Those settings seem to be important enough, as the Semantic Kernel plugin specifies
// those settings (see the URL below). We can use default values when not defined.
// https://github.com/microsoft/semantic-kernel/blob/main/samples/skills/FunSkill/Joke/config.json

ChatCompletionsOptions chatOptions;

// Determine if the gpt model is a function calling model
@@ -300,8 +314,8 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
Temperature = (float)0.0,
MaxTokens = MaxResponseToken,
};
if(isFunctionCallingModel)

if (isFunctionCallingModel)
{
chatOptions.Tools.Add(Tools.RunCode);
}
@@ -330,7 +344,7 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
- You are capable of **any** task
- Do not apologize for errors, just correct them
";
string versions = "\n## Language Versions\n"
string versions = "\n## Language Versions\n"
+ await _executionService.GetLanguageVersions();
string systemResponseCues = @"
# Examples
@@ -478,11 +492,11 @@ public override ChatRequestMessage Read(ref Utf8JsonReader reader, Type typeToCo
{
return JsonSerializer.Deserialize<ChatRequestUserMessage>(jsonObject.GetRawText(), options);
}
else if(jsonObject.TryGetProperty("Role", out JsonElement roleElementA) && roleElementA.GetString() == "assistant")
else if (jsonObject.TryGetProperty("Role", out JsonElement roleElementA) && roleElementA.GetString() == "assistant")
{
return JsonSerializer.Deserialize<ChatRequestAssistantMessage>(jsonObject.GetRawText(), options);
}
else if(jsonObject.TryGetProperty("Role", out JsonElement roleElementT) && roleElementT.GetString() == "tool")
else if (jsonObject.TryGetProperty("Role", out JsonElement roleElementT) && roleElementT.GetString() == "tool")
{
return JsonSerializer.Deserialize<ChatRequestToolMessage>(jsonObject.GetRawText(), options);
}
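
For readers unfamiliar with the Entra ID path added above, here is a minimal standalone sketch of the same pattern. It assumes the `Azure.AI.OpenAI` 1.x and `Azure.Identity` packages referenced by this agent; the endpoint URL and the optional token check are illustrative only, not part of this PR.

```csharp
// Minimal sketch (not part of this PR): create the Azure OpenAI client with
// Entra ID credentials instead of an API key, mirroring the branch added above.
using System;
using System.Threading;
using Azure.AI.OpenAI;
using Azure.Core;
using Azure.Identity;

class EntraIdAuthSketch
{
    static void Main()
    {
        // DefaultAzureCredential walks the standard chain: environment variables,
        // managed identity, Visual Studio / Azure CLI sign-in, and so on.
        var credential = new DefaultAzureCredential();

        // Optional sanity check: acquire a token for the Cognitive Services scope
        // to confirm the local credential chain works before creating the client.
        AccessToken token = credential.GetToken(
            new TokenRequestContext(new[] { "https://cognitiveservices.azure.com/.default" }),
            CancellationToken.None);
        Console.WriteLine($"Token acquired; expires {token.ExpiresOn:u}");

        // Placeholder endpoint; replace with your Azure OpenAI resource endpoint.
        var client = new OpenAIClient(
            new Uri("https://<your-resource>.openai.azure.com"),
            credential,
            new OpenAIClientOptions());
    }
}
```

The signed-in identity typically needs a data-plane role such as `Cognitive Services OpenAI User` on the Azure OpenAI resource for requests to succeed.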
23 changes: 21 additions & 2 deletions shell/agents/AIShell.Interpreter.Agent/Settings.cs
@@ -12,6 +12,12 @@ internal enum EndpointType
OpenAI,
}

public enum AuthType
{
ApiKey,
EntraID
}

internal class Settings
{
internal EndpointType Type { get; }
@@ -23,6 +29,8 @@ internal class Settings
public string ModelName { set; get; }
public SecureString Key { set; get; }

public AuthType AuthType { set; get; } = AuthType.ApiKey;

public bool AutoExecution { set; get; }
public bool DisplayErrors { set; get; }

@@ -36,6 +44,7 @@ public Settings(ConfigData configData)
AutoExecution = configData.AutoExecution ?? false;
DisplayErrors = configData.DisplayErrors ?? true;
Key = configData.Key;
AuthType = configData.AuthType;

Dirty = false;
ModelInfo = ModelInfo.TryResolve(ModelName, out var model) ? model : null;
@@ -47,6 +56,12 @@ public Settings(ConfigData configData)
: !noEndpoint && !noDeployment
? EndpointType.AzureOpenAI
: throw new InvalidOperationException($"Invalid setting: {(noEndpoint ? "Endpoint" : "Deployment")} key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");

// EntraID authentication is only supported for Azure OpenAI
if (AuthType == AuthType.EntraID && Type != EndpointType.AzureOpenAI)
{
throw new InvalidOperationException("EntraID authentication is only supported for Azure OpenAI service.");
}
}

internal void MarkClean()
@@ -60,7 +75,7 @@ internal void MarkClean()
/// <returns></returns>
internal async Task<bool> SelfCheck(IHost host, CancellationToken token)
{
if (Key is not null && ModelInfo is not null)
if ((AuthType == AuthType.ApiKey && Key is not null || AuthType == AuthType.EntraID) && ModelInfo is not null)
{
return true;
}
@@ -76,7 +91,7 @@ internal async Task<bool> SelfCheck(IHost host, CancellationToken token)
await AskForModel(host, token);
}

if (Key is null)
if (AuthType == AuthType.ApiKey && Key is null)
{
await AskForKeyAsync(host, token);
}
@@ -101,12 +116,14 @@ private void ShowEndpointInfo(IHost host)
new(label: " Endpoint", m => m.Endpoint),
new(label: " Deployment", m => m.Deployment),
new(label: " Model", m => m.ModelName),
new(label: " Auth Type", m => m.AuthType.ToString()),
],

EndpointType.OpenAI =>
[
new(label: " Type", m => m.Type.ToString()),
new(label: " Model", m => m.ModelName),
new(label: " Auth Type", m => m.AuthType.ToString()),
],

_ => throw new UnreachableException(),
@@ -156,6 +173,7 @@ internal ConfigData ToConfigData()
ModelName = this.ModelName,
AutoExecution = this.AutoExecution,
DisplayErrors = this.DisplayErrors,
AuthType = this.AuthType,
Key = this.Key,
};
}
@@ -166,6 +184,7 @@ internal class ConfigData
public string Endpoint { set; get; }
public string Deployment { set; get; }
public string ModelName { set; get; }
public AuthType AuthType { set; get; } = AuthType.ApiKey;
public bool? AutoExecution { set; get; }
public bool? DisplayErrors { set; get; }

@@ -22,6 +22,7 @@

<ItemGroup>
<PackageReference Include="Azure.AI.OpenAI" Version="2.1.0" />
<PackageReference Include="Azure.Identity" Version="1.13.2" />
<PackageReference Include="Microsoft.ML.Tokenizers" Version="1.0.1" />
<PackageReference Include="Microsoft.ML.Tokenizers.Data.O200kBase" Version="1.0.1" />
<PackageReference Include="Microsoft.ML.Tokenizers.Data.Cl100kBase" Version="1.0.1" />
18 changes: 17 additions & 1 deletion shell/agents/AIShell.OpenAI.Agent/Agent.cs
@@ -84,7 +84,7 @@ public void Initialize(AgentConfig config)
public bool CanAcceptFeedback(UserAction action) => false;

/// <inheritdoc/>
public void OnUserAction(UserActionPayload actionPayload) {}
public void OnUserAction(UserActionPayload actionPayload) { }

/// <inheritdoc/>
public Task RefreshChatAsync(IShell shell, bool force)
@@ -308,6 +308,22 @@ private void NewExampleSettingFile()
"ModelName": "gpt-4o",
"Key": "<insert your key>",
"SystemPrompt": "1. You are a helpful and friendly assistant with expertise in PowerShell scripting and command line.\n2. Assume user is using the operating system `Windows 11` unless otherwise specified.\n3. Use the `code block` syntax in markdown to encapsulate any part in responses that is code, YAML, JSON or XML, but not table.\n4. When encapsulating command line code, use '```powershell' if it's PowerShell command; use '```sh' if it's non-PowerShell CLI command.\n5. When generating CLI commands, never ever break a command into multiple lines. Instead, always list all parameters and arguments of the command on the same line.\n6. Please keep the response concise but to the point. Do not overexplain."
},

// To use Azure OpenAI service with Entra ID authentication:
// - Set `Endpoint` to the endpoint of your Azure OpenAI service.
// - Set `Deployment` to the deployment name of your Azure OpenAI service.
// - Set `ModelName` to the name of the model used for your deployment, e.g. "gpt-4o".
// - Set `AuthType` to "EntraID" to use Azure AD credentials.
// For example:
{
"Name": "ps-az-entraId",
"Description": "A GPT instance with expertise in PowerShell scripting using Entra ID authentication.",
"Endpoint": "<insert your Azure OpenAI endpoint>",
"Deployment": "<insert your deployment name>",
"ModelName": "gpt-4o",
"AuthType": "EntraID",
"SystemPrompt": "1. You are a helpful and friendly assistant with expertise in PowerShell scripting and command line."
}
*/
],
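
The openai-gpt agent references `Azure.AI.OpenAI` 2.1.0, where the Azure-specific client type is `AzureOpenAIClient`. The corresponding service-side change is not shown in this excerpt, but a minimal sketch of how a 2.x client can consume the same `DefaultAzureCredential` is given below; the endpoint, deployment name, and prompt are placeholders, not values from this PR.

```csharp
// Minimal sketch (not part of this PR): Entra ID auth with the Azure.AI.OpenAI 2.x API.
using System;
using Azure.AI.OpenAI;
using Azure.Identity;
using OpenAI.Chat;

class GptEntraIdSketch
{
    static void Main()
    {
        // Token-based credential; with "AuthType": "EntraID" no "Key" entry is required.
        var azureClient = new AzureOpenAIClient(
            new Uri("https://<your-resource>.openai.azure.com"), // placeholder endpoint
            new DefaultAzureCredential());

        // Chat requests go through a deployment-scoped ChatClient.
        ChatClient chat = azureClient.GetChatClient("<your-deployment-name>");
        ChatCompletion completion = chat.CompleteChat("Say hello from Entra ID auth.");
        Console.WriteLine(completion.Content[0].Text);
    }
}
```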