diff --git a/gpt4all-bindings/csharp/Directory.Build.props b/gpt4all-bindings/csharp/Directory.Build.props
index 8ce38889..9f7cf5bf 100644
--- a/gpt4all-bindings/csharp/Directory.Build.props
+++ b/gpt4all-bindings/csharp/Directory.Build.props
@@ -5,7 +5,7 @@
en-US
- 0.6.0
+ 0.6.1-alpha
$(VersionSuffix)
$(Version)$(VersionSuffix)
true
diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
index bcda1d85..206b00cf 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLModel.cs
@@ -1,4 +1,7 @@
-namespace Gpt4All.Bindings;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
+
+namespace Gpt4All.Bindings;
/// <summary>
/// Arguments for the response processing callback
@@ -40,14 +43,16 @@ public class LLModel : ILLModel
{
protected readonly IntPtr _handle;
private readonly ModelType _modelType;
+ private readonly ILogger _logger;
private bool _disposed;
public ModelType ModelType => _modelType;
- internal LLModel(IntPtr handle, ModelType modelType)
+ internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
{
_handle = handle;
_modelType = modelType;
+ _logger = logger ?? NullLogger.Instance;
}
/// <summary>
@@ -55,9 +60,9 @@ public class LLModel : ILLModel
/// </summary>
/// <param name="handle">Pointer to underlying model</param>
/// <param name="modelType">The model type</param>
- public static LLModel Create(IntPtr handle, ModelType modelType)
+ public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
{
- return new LLModel(handle, modelType);
+ return new LLModel(handle, modelType, logger: logger);
}
/// <summary>
@@ -82,6 +87,8 @@ public class LLModel : ILLModel
GC.KeepAlive(recalculateCallback);
GC.KeepAlive(cancellationToken);
+ _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());
+
NativeMethods.llmodel_prompt(
_handle,
text,
@@ -94,7 +101,12 @@ public class LLModel : ILLModel
},
(tokenId, response) =>
{
- if (cancellationToken.IsCancellationRequested) return false;
+ if (cancellationToken.IsCancellationRequested)
+ {
+ _logger.LogDebug("ResponseCallback evt=CancellationRequested");
+ return false;
+ }
+
if (responseCallback == null) return true;
var args = new ModelResponseEventArgs(tokenId, response);
return responseCallback(args);
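
A minimal sketch of driving this callback surface from calling code, assuming an already-loaded model and prompt context; the timeout value is illustrative:

    var cts = new CancellationTokenSource(TimeSpan.FromSeconds(30));

    model.Prompt("Hello world", context, responseCallback: e =>
    {
        Console.Write(e.Response); // stream tokens as they arrive
        return true;               // returning false aborts generation, as the error path above does
    }, cancellationToken: cts.Token);
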
diff --git a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
index ce78938f..eeec504d 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Bindings/LLPromptContext.cs
@@ -1,6 +1,4 @@
-using System.Reflection;
-
-namespace Gpt4All.Bindings;
+namespace Gpt4All.Bindings;
/// <summary>
/// Wrapper around the llmodel_prompt_context structure for holding the prompt context.
diff --git a/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs b/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs
new file mode 100644
index 00000000..4426ef49
--- /dev/null
+++ b/gpt4all-bindings/csharp/Gpt4All/Extensions/LLPromptContextExtensions.cs
@@ -0,0 +1,26 @@
+using Gpt4All.Bindings;
+
+namespace Gpt4All;
+
+internal static class LLPromptContextExtensions
+{
+ public static string Dump(this LLModelPromptContext context)
+ {
+ var ctx = context.UnderlyingContext;
+ return @$"
+ {{
+ logits_size = {ctx.logits_size}
+ tokens_size = {ctx.tokens_size}
+ n_past = {ctx.n_past}
+ n_ctx = {ctx.n_ctx}
+ n_predict = {ctx.n_predict}
+ top_k = {ctx.top_k}
+ top_p = {ctx.top_p}
+ temp = {ctx.temp}
+ n_batch = {ctx.n_batch}
+ repeat_penalty = {ctx.repeat_penalty}
+ repeat_last_n = {ctx.repeat_last_n}
+ context_erase = {ctx.context_erase}
+ }}";
+ }
+}
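
Dump() flattens the native llmodel_prompt_context into a single brace-delimited string, which is what LLModel.Prompt embeds in its prompt log entry. A standalone sketch, assuming PredictRequestOptions.Defaults is the binding's default option set and logger is an ILogger in scope:

    var context = PredictRequestOptions.Defaults.ToPromptContext();
    logger.LogInformation("Prompt ctx={Context}", context.Dump());
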
diff --git a/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs b/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs
index f6e1f016..48ebd1f1 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Extensions/PredictRequestOptionsExtensions.cs
@@ -1,6 +1,6 @@
using Gpt4All.Bindings;
-namespace Gpt4All.Extensions;
+namespace Gpt4All;
public static class PredictRequestOptionsExtensions
{
diff --git a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.cs b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.cs
index d129ace4..318361a0 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.cs
@@ -1,18 +1,25 @@
-using Gpt4All.Bindings;
-using Gpt4All.Extensions;
+using System.Diagnostics;
+using Gpt4All.Bindings;
+using Microsoft.Extensions.Logging;
+using Microsoft.Extensions.Logging.Abstractions;
namespace Gpt4All;
public class Gpt4All : IGpt4AllModel
{
private readonly ILLModel _model;
+ private readonly ILogger _logger;
+
+ private const string ResponseErrorMessage =
+ "The model reported an error during token generation error={ResponseError}";
/// <inheritdoc/>
public IPromptFormatter? PromptFormatter { get; set; }
- internal Gpt4All(ILLModel model)
+ internal Gpt4All(ILLModel model, ILogger? logger = null)
{
_model = model;
+ _logger = logger ?? NullLogger.Instance;
PromptFormatter = new DefaultPromptFormatter();
}
@@ -29,21 +36,36 @@ public class Gpt4All : IGpt4AllModel
return Task.Run(() =>
{
+ _logger.LogInformation("Start prediction task");
+
+ var sw = Stopwatch.StartNew();
var result = new TextPredictionResult();
var context = opts.ToPromptContext();
var prompt = FormatPrompt(text);
- _model.Prompt(prompt, context, responseCallback: e =>
+ try
{
- if (e.IsError)
+ _model.Prompt(prompt, context, responseCallback: e =>
{
- result.Success = false;
- result.ErrorMessage = e.Response;
- return false;
- }
- result.Append(e.Response);
- return true;
- }, cancellationToken: cancellationToken);
+ if (e.IsError)
+ {
+ _logger.LogWarning(ResponseErrorMessage, e.Response);
+ result.Success = false;
+ result.ErrorMessage = e.Response;
+ return false;
+ }
+ result.Append(e.Response);
+ return true;
+ }, cancellationToken: cancellationToken);
+ }
+ catch (Exception e)
+ {
+ _logger.LogError(e, "Prompt error");
+ result.Success = false;
+ }
+
+ sw.Stop();
+ _logger.LogInformation("Prediction task completed elapsed={Elapsed}s", sw.Elapsed.TotalSeconds);
return (ITextPredictionResult)result;
}, CancellationToken.None);
@@ -57,6 +79,9 @@ public class Gpt4All : IGpt4AllModel
_ = Task.Run(() =>
{
+ _logger.LogInformation("Start streaming prediction task");
+ var sw = Stopwatch.StartNew();
+
try
{
var context = opts.ToPromptContext();
@@ -66,6 +91,7 @@ public class Gpt4All : IGpt4AllModel
{
if (e.IsError)
{
+ _logger.LogWarning(ResponseErrorMessage, e.Response);
result.Success = false;
result.ErrorMessage = e.Response;
return false;
@@ -74,9 +100,16 @@ public class Gpt4All : IGpt4AllModel
return true;
}, cancellationToken: cancellationToken);
}
+ catch (Exception e)
+ {
+ _logger.LogError(e, "Prompt error");
+ result.Success = false;
+ }
finally
{
result.Complete();
+ sw.Stop();
+ _logger.LogInformation("Prediction task completed elapsed={Elapsed}s", sw.Elapsed.TotalSeconds);
}
}, CancellationToken.None);
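
With these changes, native prompt failures no longer escape either prediction task: both paths catch the exception, log it, and surface it through the result object, so callers should check Success rather than rely on exceptions. A consuming sketch (the GetStreamingPredictionAsync and GetPredictionStreamingAsync names mirror the binding's public streaming API and are assumptions here):

    var result = await model.GetStreamingPredictionAsync(prompt, PredictRequestOptions.Defaults);

    await foreach (var token in result.GetPredictionStreamingAsync())
        Console.Write(token); // tokens arrive as they are generated

    if (!result.Success)
        Console.Error.WriteLine($"prediction failed: {result.ErrorMessage}");
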
diff --git a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj
index dc2d96fa..db6780fe 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj
+++ b/gpt4all-bindings/csharp/Gpt4All/Gpt4All.csproj
@@ -20,4 +20,8 @@
+
+  <ItemGroup>
+    <PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="7.0.0" />
+  </ItemGroup>
diff --git a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs
index 6d4c7875..3c36ac26 100644
--- a/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs
+++ b/gpt4all-bindings/csharp/Gpt4All/Model/Gpt4AllModelFactory.cs
@@ -1,14 +1,27 @@
-using Gpt4All.Bindings;
-using System.Diagnostics;
+using System.Diagnostics;
+using Microsoft.Extensions.Logging;
+using Gpt4All.Bindings;
+using Microsoft.Extensions.Logging.Abstractions;
namespace Gpt4All;
public class Gpt4AllModelFactory : IGpt4AllModelFactory
{
- private static IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
+ private readonly ILoggerFactory _loggerFactory;
+ private readonly ILogger _logger;
+
+ public Gpt4AllModelFactory(ILoggerFactory? loggerFactory = null)
+ {
+ _loggerFactory = loggerFactory ?? NullLoggerFactory.Instance;
+ _logger = _loggerFactory.CreateLogger<Gpt4AllModelFactory>();
+ }
+
+ private IGpt4AllModel CreateModel(string modelPath, ModelType? modelType = null)
{
var modelType_ = modelType ?? ModelFileUtils.GetModelTypeFromModelFileHeader(modelPath);
+ _logger.LogInformation("Creating model path={ModelPath} type={ModelType}", modelPath, modelType_);
+
var handle = modelType_ switch
{
ModelType.LLAMA => NativeMethods.llmodel_llama_create(),
@@ -17,18 +30,25 @@ public class Gpt4AllModelFactory : IGpt4AllModelFactory
_ => NativeMethods.llmodel_model_create(modelPath),
};
- var loadedSuccesfully = NativeMethods.llmodel_loadModel(handle, modelPath);
+ _logger.LogDebug("Model created handle=0x{ModelHandle:X8}", handle);
+ _logger.LogInformation("Model loading started");
- if (loadedSuccesfully == false)
+ var loadedSuccessfully = NativeMethods.llmodel_loadModel(handle, modelPath);
+
+ _logger.LogInformation("Model loading completed success={ModelLoadSuccess}", loadedSuccessfully);
+
+ if (loadedSuccessfully == false)
{
throw new Exception($"Failed to load model: '{modelPath}'");
}
- var underlyingModel = LLModel.Create(handle, modelType_);
+ var logger = _loggerFactory.CreateLogger<LLModel>();
+
+ var underlyingModel = LLModel.Create(handle, modelType_, logger: logger);
Debug.Assert(underlyingModel.IsLoaded());
- return new Gpt4All(underlyingModel);
+ return new Gpt4All(underlyingModel, logger: logger);
}
public IGpt4AllModel LoadModel(string modelPath) => CreateModel(modelPath, modelType: null);
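
Taken together, the factory can now be wired to any Microsoft.Extensions.Logging provider; omitting the argument keeps the previous silent behavior via NullLoggerFactory. A minimal sketch, assuming the Microsoft.Extensions.Logging.Console provider package is referenced by the consuming app and the model path points at a local weights file:

    using Microsoft.Extensions.Logging;

    using var loggerFactory = LoggerFactory.Create(builder =>
        builder.AddConsole().SetMinimumLevel(LogLevel.Debug));

    var factory = new Gpt4AllModelFactory(loggerFactory);
    using var model = factory.LoadModel("./ggml-gpt4all-j-v1.3-groovy.bin"); // illustrative path
    // model is released at scope exit, assuming the binding's IDisposable implementation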