C# Bindings - improved logging (#714)

* added optional support for .NET logging

* bump version and add missing alpha suffix

* avoid creating additional namespace for extensions

* prefer NullLogger/NullLoggerFactory over null-conditional ILogger to avoid errors

---------

Signed-off-by: mvenditto <venditto.matteo@gmail.com>
pull/913/head
mvenditto 1 year ago committed by GitHub
parent d15ff7de40
commit fc5537badf
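
For reference, a minimal sketch of how a consumer can opt into the new logging support. It uses only the constructor and factory signatures introduced in this diff; the console provider (AddConsole) comes from the separate Microsoft.Extensions.Logging.Console package, and the model path is a placeholder:

using Gpt4All;
using Microsoft.Extensions.Logging;

// Any ILoggerFactory works; omitting it falls back to NullLoggerFactory.Instance.
// AddConsole() assumes the Microsoft.Extensions.Logging.Console package is referenced.
using var loggerFactory = LoggerFactory.Create(builder => builder.AddConsole());

var modelFactory = new Gpt4AllModelFactory(loggerFactory);

// Placeholder model path, for illustration only.
var model = modelFactory.LoadModel("./model.bin");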

@@ -5,7 +5,7 @@
     <Company></Company>
     <Copyright></Copyright>
     <NeutralLanguage>en-US</NeutralLanguage>
-    <Version>0.6.0</Version>
+    <Version>0.6.1-alpha</Version>
     <VersionSuffix>$(VersionSuffix)</VersionSuffix>
     <Version Condition=" '$(VersionSuffix)' != '' ">$(Version)$(VersionSuffix)</Version>
     <TreatWarningsAsErrors>true</TreatWarningsAsErrors>

@@ -1,4 +1,7 @@
-namespace Gpt4All.Bindings;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+
+namespace Gpt4All.Bindings;
 
 /// <summary>
 /// Arguments for the response processing callback
@@ -40,14 +43,16 @@ public class LLModel : ILLModel
 {
     protected readonly IntPtr _handle;
     private readonly ModelType _modelType;
+    private readonly ILogger _logger;
     private bool _disposed;
 
     public ModelType ModelType => _modelType;
 
-    internal LLModel(IntPtr handle, ModelType modelType)
+    internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
     {
         _handle = handle;
         _modelType = modelType;
+        _logger = logger ?? NullLogger.Instance;
     }
 
     /// <summary>
@@ -55,9 +60,9 @@ public class LLModel : ILLModel
     /// </summary>
     /// <param name="handle">Pointer to underlying model</param>
     /// <param name="modelType">The model type</param>
-    public static LLModel Create(IntPtr handle, ModelType modelType)
+    public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
     {
-        return new LLModel(handle, modelType);
+        return new LLModel(handle, modelType, logger: logger);
     }
 
     /// <summary>
@@ -82,6 +87,8 @@ public class LLModel : ILLModel
         GC.KeepAlive(recalculateCallback);
         GC.KeepAlive(cancellationToken);
 
+        _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());
+
         NativeMethods.llmodel_prompt(
             _handle,
             text,
@@ -94,7 +101,12 @@ public class LLModel : ILLModel
             },
             (tokenId, response) =>
             {
-                if (cancellationToken.IsCancellationRequested) return false;
+                if (cancellationToken.IsCancellationRequested)
+                {
+                    _logger.LogDebug("ResponseCallback evt=CancellationRequested");
+                    return false;
+                }
+
                 if (responseCallback == null) return true;
                 var args = new ModelResponseEventArgs(tokenId, response);
                 return responseCallback(args);

@@ -1,6 +1,4 @@
-using System.Reflection;
-
 namespace Gpt4All.Bindings;
 
 /// <summary>
 /// Wrapper around the llmodel_prompt_context structure for holding the prompt context.

@@ -0,0 +1,26 @@
+using Gpt4All.Bindings;
+
+namespace Gpt4All;
+
+internal static class LLPromptContextExtensions
+{
+    public static string Dump(this LLModelPromptContext context)
+    {
+        var ctx = context.UnderlyingContext;
+        return @$"
+        {{
+            logits_size = {ctx.logits_size}
+            tokens_size = {ctx.tokens_size}
+            n_past = {ctx.n_past}
+            n_ctx = {ctx.n_ctx}
+            n_predict = {ctx.n_predict}
+            top_k = {ctx.top_k}
+            top_p = {ctx.top_p}
+            temp = {ctx.temp}
+            n_batch = {ctx.n_batch}
+            repeat_penalty = {ctx.repeat_penalty}
+            repeat_last_n = {ctx.repeat_last_n}
+            context_erase = {ctx.context_erase}
+        }}";
+    }
+}

@@ -1,6 +1,6 @@
 using Gpt4All.Bindings;
 
-namespace Gpt4All.Extensions;
+namespace Gpt4All;
 
 public static class PredictRequestOptionsExtensions
 {

@@ -1,18 +1,25 @@
-using Gpt4All.Bindings;
-using Gpt4All.Extensions;
+using System.Diagnostics;
+using Gpt4All.Bindings;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
 
 namespace Gpt4All;
 
 public class Gpt4All : IGpt4AllModel
 {
     private readonly ILLModel _model;
+    private readonly ILogger _logger;
+
+    private const string ResponseErrorMessage =
+        "The model reported an error during token generation error={ResponseError}";
 
     /// <inheritdoc/>
     public IPromptFormatter? PromptFormatter { get; set; }
 
-    internal Gpt4All(ILLModel model)
+    internal Gpt4All(ILLModel model, ILogger? logger = null)
     {
         _model = model;
+        _logger = logger ?? NullLogger.Instance;
         PromptFormatter = new DefaultPromptFormatter();
     }
@@ -29,21 +36,36 @@ public class Gpt4All : IGpt4AllModel
         return Task.Run(() =>
         {
+            _logger.LogInformation("Start prediction task");
+            var sw = Stopwatch.StartNew();
             var result = new TextPredictionResult();
             var context = opts.ToPromptContext();
             var prompt = FormatPrompt(text);
 
-            _model.Prompt(prompt, context, responseCallback: e =>
+            try
             {
-                if (e.IsError)
+                _model.Prompt(prompt, context, responseCallback: e =>
                 {
-                    result.Success = false;
-                    result.ErrorMessage = e.Response;
-                    return false;
-                }
-                result.Append(e.Response);
-                return true;
-            }, cancellationToken: cancellationToken);
+                    if (e.IsError)
+                    {
+                        _logger.LogWarning(ResponseErrorMessage, e.Response);
+                        result.Success = false;
+                        result.ErrorMessage = e.Response;
+                        return false;
+                    }
+                    result.Append(e.Response);
+                    return true;
+                }, cancellationToken: cancellationToken);
+            }
+            catch (Exception e)
+            {
+                _logger.LogError(e, "Prompt error");
+                result.Success = false;
+            }
+
+            sw.Stop();
+            _logger.LogInformation("Prediction task completed elapsed={Elapsed}s", sw.Elapsed.TotalSeconds);
 
             return (ITextPredictionResult)result;
         }, CancellationToken.None);
@@ -57,6 +79,9 @@ public class Gpt4All : IGpt4AllModel
         _ = Task.Run(() =>
         {
+            _logger.LogInformation("Start streaming prediction task");
+            var sw = Stopwatch.StartNew();
+
             try
             {
                 var context = opts.ToPromptContext();
@@ -66,6 +91,7 @@ public class Gpt4All : IGpt4AllModel
                 {
                     if (e.IsError)
                     {
+                        _logger.LogWarning(ResponseErrorMessage, e.Response);
                         result.Success = false;
                         result.ErrorMessage = e.Response;
                         return false;
@@ -74,9 +100,16 @@ public class Gpt4All : IGpt4AllModel
                     return true;
                 }, cancellationToken: cancellationToken);
             }
+            catch (Exception e)
+            {
+                _logger.LogError(e, "Prompt error");
+                result.Success = false;
+            }
             finally
             {
                 result.Complete();
+                sw.Stop();
+                _logger.LogInformation("Prediction task completed elapsed={Elapsed}s", sw.Elapsed.TotalSeconds);
             }
         }, CancellationToken.None);

@@ -20,4 +20,8 @@
     <!-- Linux -->
     <None Condition="$([MSBuild]::IsOSPlatform('Linux'))" Include="..\runtimes\linux-x64\native\*.so" Visible="False" CopyToOutputDirectory="PreserveNewest" />
   </ItemGroup>
+
+  <ItemGroup>
+    <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="7.0.0" />
+  </ItemGroup>
 </Project>

@@ -1,14 +1,27 @@
-using Gpt4All.Bindings;
-using System.Diagnostics;
+using System.Diagnostics;
+using Microsoft.Extensions.Logging;
+using Gpt4All.Bindings;
+using Microsoft.Extensions.Logging.Abstractions;
 
 namespace Gpt4All;
 
 public class Gpt4AllModelFactory : IGpt4AllModelFactory
 {
-    private static IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
+    private readonly ILoggerFactory _loggerFactory;
+    private readonly ILogger _logger;
+
+    public Gpt4AllModelFactory(ILoggerFactory? loggerFactory = null)
+    {
+        _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
+        _logger = _loggerFactory.CreateLogger<Gpt4AllModelFactory>();
+    }
+
+    private IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
     {
         var modelType_ = modelType ?? ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);
 
+        _logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_);
+
         var handle = modelType_ switch
         {
             ModelType.LLAMA => NativeMethods.llmodel_llama_create(),
@@ -17,18 +30,25 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
             _ => NativeMethods.llmodel_model_create(modelPath),
         };
 
-        var loadedSuccesfully = NativeMethods.llmodel_loadModel(handle, modelPath);
+        _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
+        _logger.LogInformation("Model loading started");
 
-        if (loadedSuccesfully == false)
+        var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
+
+        _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
+
+        if (loadedSuccessfully == false)
         {
             throw new Exception($"Failed to load model: '{modelPath}'");
         }
 
-        var underlyingModel = LLModel.Create(handle, modelType_);
+        var logger = _loggerFactory.CreateLogger<LLModel>();
+        var underlyingModel = LLModel.Create(handle, modelType_, logger: logger);
 
         Debug.Assert(underlyingModel.IsLoaded());
 
-        return new Gpt4All(underlyingModel);
+        return new Gpt4All(underlyingModel, logger: logger);
     }
 
     public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath, modelType: null);
