diff --git a/gpt4all-bindings/csharp/Directory.Build.props b/gpt4all-bindings/csharp/Directory.Build.props
index 8ce38889..9f7cf5bf 100644
--- a/gpt4all-bindings/csharp/Directory.Build.props
+++ b/gpt4all-bindings/csharp/Directory.Build.props
@@ -5,7 +5,7 @@
     en-US
-    0.6.0
+    0.6.1-alpha
     $(VersionSuffix)
     $(Version)$(VersionSuffix)
     true
diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
index bcda1d85..206b00cf 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
@@ -1,4 +1,7 @@
-namespace Gpt4All.Bindings;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+
+namespace Gpt4All.Bindings;
 
 /// <summary>
 /// Arguments for the response processing callback
@@ -40,14 +43,16 @@ public class LLModel : ILLModel
 {
     protected readonly IntPtr _handle;
     private readonly ModelType _modelType;
+    private readonly ILogger _logger;
     private bool _disposed;
 
     public ModelType ModelType => _modelType;
 
-    internal LLModel(IntPtr handle, ModelType modelType)
+    internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
     {
         _handle = handle;
         _modelType = modelType;
+        _logger = logger ?? NullLogger.Instance;
     }
 
     /// <summary>
@@ -55,9 +60,9 @@ public class LLModel : ILLModel
     /// </summary>
     /// <param name="handle">Pointer to underlying model</param>
    /// <param name="modelType">The model type</param>
-    public static LLModel Create(IntPtr handle, ModelType modelType)
+    public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
-        return new LLModel(handle, modelType);
+        return new LLModel(handle, modelType, logger: logger);
    }
 
    /// <summary>
@@ -82,6 +87,8 @@ public class LLModel : ILLModel
        GC.KeepAlive(recalculateCallback);
        GC.KeepAlive(cancellationToken);
 
+        _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());
+
        NativeMethods.llmodel_prompt(
            _handle,
            text,
@@ -94,7 +101,12 @@ public class LLModel : ILLModel
            },
            (tokenId, response) =>
            {
-                if (cancellationToken.IsCancellationRequested) return false;
+                if (cancellationToken.IsCancellationRequested)
+                {
+                    _logger.LogDebug("ResponseCallback evt=CancellationRequested");
+                    return false;
+                }
+
                if (responseCallback == null) return true;
                var args = new ModelResponseEventArgs(tokenId, response);
                return responseCallback(args);
diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
index ce78938f..eeec504d 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
@@ -1,6 +1,4 @@
-using System.Reflection;
-
-namespace Gpt4All.Bindings;
+namespace Gpt4All.Bindings;
 
 /// <summary>
 /// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
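The `LLModel` changes above rely on the null-object idiom from `Microsoft.Extensions.Logging.Abstractions`: the logger is an optional constructor argument, and `NullLogger.Instance` silently swallows every call when none is supplied. A minimal sketch of the pattern, using a hypothetical `Worker` class that is not part of this diff:

```csharp
using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

// Null-object pattern: callers that don't care about logging pass nothing,
// and every _logger call becomes a no-op instead of needing a null check.
public sealed class Worker
{
    private readonly ILogger _logger;

    public Worker(ILogger? logger = null)
    {
        // NullLogger.Instance is a singleton whose Log methods do nothing.
        _logger = logger ?? NullLogger.Instance;
    }

    public void Run()
    {
        // '{Step}' is a message-template placeholder, bound positionally;
        // structured sinks keep it as a queryable property named Step.
        _logger.LogInformation("Run step={Step}", 1);
    }
}
```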
diff --git a/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs b/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs
new file mode 100644
index 00000000..4426ef49
--- /dev/null
+++ b/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs
@@ -0,0 +1,26 @@
+using Gpt4All.Bindings;
+
+namespace Gpt4All;
+
+internal static class LLPromptContextExtensions
+{
+    public static string Dump(this LLModelPromptContext context)
+    {
+        var ctx = context.UnderlyingContext;
+        return @$"
+        {{
+            logits_size = {ctx.logits_size}
+            tokens_size = {ctx.tokens_size}
+            n_past = {ctx.n_past}
+            n_ctx = {ctx.n_ctx}
+            n_predict = {ctx.n_predict}
+            top_k = {ctx.top_k}
+            top_p = {ctx.top_p}
+            temp = {ctx.temp}
+            n_batch = {ctx.n_batch}
+            repeat_penalty = {ctx.repeat_penalty}
+            repeat_last_n = {ctx.repeat_last_n}
+            context_erase = {ctx.context_erase}
+        }}";
+    }
+}
diff --git a/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs b/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs
index f6e1f016..48ebd1f1 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs
@@ -1,6 +1,6 @@
 using Gpt4All.Bindings;
 
-namespace Gpt4All.Extensions;
+namespace Gpt4All;
 
 public static class PredictRequestOptionsExtensions
 {
diff --git a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.cs b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.cs
index d129ace4..318361a0 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.cs
@@ -1,18 +1,25 @@
-using Gpt4All.Bindings;
-using Gpt4All.Extensions;
+using System.Diagnostics;
+using Gpt4All.Bindings;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
 
 namespace Gpt4All;
 
 public class Gpt4All : IGpt4AllModel
 {
     private readonly ILLModel _model;
+    private readonly ILogger _logger;
+
+    private const string ResponseErrorMessage =
+        "The model reported an error during token generation error={ResponseError}";
 
     /// <inheritdoc/>
     public IPromptFormatter? PromptFormatter { get; set; }
 
-    internal Gpt4All(ILLModel model)
+    internal Gpt4All(ILLModel model, ILogger? logger = null)
     {
         _model = model;
+        _logger = logger ?? NullLogger.Instance;
         PromptFormatter = new DefaultPromptFormatter();
     }
 
@@ -29,21 +36,36 @@ public class Gpt4All : IGpt4AllModel
         return Task.Run(() =>
         {
+            _logger.LogInformation("Start prediction task");
+
+            var sw = Stopwatch.StartNew();
             var result = new TextPredictionResult();
             var context = opts.ToPromptContext();
             var prompt = FormatPrompt(text);
 
-            _model.Prompt(prompt, context, responseCallback: e =>
+            try
             {
-                if (e.IsError)
+                _model.Prompt(prompt, context, responseCallback: e =>
                 {
-                    result.Success = false;
-                    result.ErrorMessage = e.Response;
-                    return false;
-                }
-                result.Append(e.Response);
-                return true;
-            }, cancellationToken: cancellationToken);
+                    if (e.IsError)
+                    {
+                        _logger.LogWarning(ResponseErrorMessage, e.Response);
+                        result.Success = false;
+                        result.ErrorMessage = e.Response;
+                        return false;
+                    }
+                    result.Append(e.Response);
+                    return true;
+                }, cancellationToken: cancellationToken);
+            }
+            catch (Exception e)
+            {
+                _logger.LogError(e, "Prompt error");
+                result.Success = false;
+            }
+
+            sw.Stop();
+            _logger.LogInformation("Prediction task completed elapsed={Elapsed}s", sw.Elapsed.TotalSeconds);
 
             return (ITextPredictionResult)result;
         }, CancellationToken.None);
@@ -57,6 +79,9 @@ public class Gpt4All : IGpt4AllModel
 
         _ = Task.Run(() =>
         {
+            _logger.LogInformation("Start streaming prediction task");
+            var sw = Stopwatch.StartNew();
+
             try
             {
                 var context = opts.ToPromptContext();
@@ -66,6 +91,7 @@ public class Gpt4All : IGpt4AllModel
                 {
                     if (e.IsError)
                     {
+                        _logger.LogWarning(ResponseErrorMessage, e.Response);
                         result.Success = false;
                         result.ErrorMessage = e.Response;
                         return false;
@@ -74,9 +100,16 @@ public class Gpt4All : IGpt4AllModel
                     return true;
                 }, cancellationToken: cancellationToken);
             }
+            catch (Exception e)
+            {
+                _logger.LogError(e, "Prompt error");
+                result.Success = false;
+            }
             finally
             {
                 result.Complete();
+                sw.Stop();
+                _logger.LogInformation("Prediction task completed elapsed={Elapsed}s", sw.Elapsed.TotalSeconds);
             }
         }, CancellationToken.None);
diff --git a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj
index dc2d96fa..db6780fe 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj
+++ b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj
@@ -20,4 +20,8 @@
   </ItemGroup>
 
+  <ItemGroup>
+    <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="7.0.0" />
+  </ItemGroup>
+
 </Project>
diff --git a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs
index 6d4c7875..3c36ac26 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs
@@ -1,14 +1,27 @@
-using Gpt4All.Bindings;
-using System.Diagnostics;
+using System.Diagnostics;
+using Microsoft.Extensions.Logging;
+using Gpt4All.Bindings;
+using Microsoft.Extensions.Logging.Abstractions;
 
 namespace Gpt4All;
 
 public class Gpt4AllModelFactory : IGpt4AllModelFactory
 {
-    private static IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
+    private readonly ILoggerFactory _loggerFactory;
+    private readonly ILogger _logger;
+
+    public Gpt4AllModelFactory(ILoggerFactory? loggerFactory = null)
+    {
+        _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
+        _logger = _loggerFactory.CreateLogger<Gpt4AllModelFactory>();
+    }
+
+    private IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
    {
         var modelType_ = modelType ?? ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);
+        _logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_);
+
         var handle = modelType_ switch
         {
             ModelType.LLAMA => NativeMethods.llmodel_llama_create(),
@@ -17,18 +30,25 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
             _ => NativeMethods.llmodel_model_create(modelPath),
         };
 
-        var loadedSuccesfully = NativeMethods.llmodel_loadModel(handle, modelPath);
+        _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
+        _logger.LogInformation("Model loading started");
 
-        if (loadedSuccesfully == false)
+        var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
+
+        _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
+
+        if (loadedSuccessfully == false)
         {
             throw new Exception($"Failed to load model: '{modelPath}'");
         }
 
-        var underlyingModel = LLModel.Create(handle, modelType_);
+        var logger = _loggerFactory.CreateLogger<LLModel>();
+
+        var underlyingModel = LLModel.Create(handle, modelType_, logger: logger);
 
         Debug.Assert(underlyingModel.IsLoaded());
 
-        return new Gpt4All(underlyingModel);
+        return new Gpt4All(underlyingModel, logger: logger);
     }
 
     public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath, modelType: null);
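For downstream consumers, a usage sketch under assumptions: any `ILoggerFactory` can be handed to the new `Gpt4AllModelFactory` constructor. The model path below is a placeholder, and `AddConsole` comes from the separate Microsoft.Extensions.Logging.Console package, which this diff does not reference:

```csharp
using Gpt4All;
using Microsoft.Extensions.Logging;

// Any ILoggerFactory works; a console factory is the simplest way to see
// the new log output.
using var loggerFactory = LoggerFactory.Create(builder => builder
    .SetMinimumLevel(LogLevel.Debug) // surface the LogDebug calls as well
    .AddConsole());

// Passing null (or using the parameterless form) falls back to
// NullLoggerFactory.Instance, so existing callers are unaffected.
var factory = new Gpt4AllModelFactory(loggerFactory);

var model = factory.LoadModel("./models/model.bin"); // placeholder path
```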