using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Logging.Abstractions;

namespace Gpt4All.Bindings;

/// <summary>
/// Arguments for the response processing callback
/// </summary>
/// <param name="TokenId">The token id of the response</param>
/// <param name="Response">The response string. NOTE: a token_id of -1 indicates the string is an error string</param>
/// <returns>
/// A bool indicating whether the model should keep generating
/// </returns>
public record ModelResponseEventArgs(int TokenId, string Response)
{
    public bool IsError => TokenId == -1;
}

/// <summary>
/// Arguments for the prompt processing callback
/// </summary>
/// <param name="TokenId">The token id of the prompt</param>
/// <returns>
/// A bool indicating whether the model should keep processing
/// </returns>
public record ModelPromptEventArgs(int TokenId);

/// <summary>
/// Arguments for the recalculating callback
/// </summary>
/// <param name="IsRecalculating">Whether the model is recalculating the context.</param>
/// <returns>
/// A bool indicating whether the model should keep generating
/// </returns>
public record ModelRecalculatingEventArgs(bool IsRecalculating);

/// <summary>
/// Base class and universal wrapper for GPT4All language models built around the llmodel C-API.
/// </summary>
public class LLModel : ILLModel
{
    protected readonly IntPtr _handle;
    private readonly ModelType _modelType;
    private readonly ILogger _logger;
    private bool _disposed;

    public ModelType ModelType => _modelType;

    internal LLModel(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        _handle = handle;
        _modelType = modelType;
        _logger = logger ?? NullLogger.Instance;
    }

    /// <summary>
    /// Create a new model from a pointer
    /// </summary>
    /// <param name="handle">Pointer to the underlying model</param>
    /// <param name="modelType">The model type</param>
    /// <param name="logger">Optional logger; a no-op logger is used when omitted</param>
    public static LLModel Create(IntPtr handle, ModelType modelType, ILogger? logger = null)
    {
        return new LLModel(handle, modelType, logger: logger);
    }

    /// <summary>
    /// Generate a response using the model
    /// </summary>
    /// <param name="text">The input prompt</param>
    /// <param name="context">The context</param>
    /// <param name="promptCallback">A callback function for handling the processing of the prompt</param>
    /// <param name="responseCallback">A callback function for handling the generated response</param>
    /// <param name="recalculateCallback">A callback function for handling recalculation requests</param>
    /// <param name="cancellationToken">A token used to stop generation between callback invocations</param>
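    /// <example>
    /// A minimal usage sketch; it assumes <c>model</c> is a loaded <see cref="LLModel"/> and
    /// <c>ctx</c> is an <see cref="LLModelPromptContext"/> configured elsewhere:
    /// <code>
    /// using var cts = new CancellationTokenSource(TimeSpan.FromMinutes(2));
    ///
    /// model.Prompt(
    ///     "Name three colors.",
    ///     ctx,
    ///     responseCallback: args =>
    ///     {
    ///         if (args.IsError) return false; // a token id of -1 signals an error string
    ///         Console.Write(args.Response);   // stream tokens as they arrive
    ///         return true;                    // keep generating
    ///     },
    ///     cancellationToken: cts.Token);
    /// </code>
    /// </example>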
    public void Prompt(
        string text,
        LLModelPromptContext context,
        Func<ModelPromptEventArgs, bool>? promptCallback = null,
        Func<ModelResponseEventArgs, bool>? responseCallback = null,
        Func<ModelRecalculatingEventArgs, bool>? recalculateCallback = null,
        CancellationToken cancellationToken = default)
    {
        GC.KeepAlive(promptCallback);
        GC.KeepAlive(responseCallback);
        GC.KeepAlive(recalculateCallback);
        GC.KeepAlive(cancellationToken);

        _logger.LogInformation("Prompt input='{Prompt}' ctx={Context}", text, context.Dump());

        NativeMethods.llmodel_prompt(
            _handle,
            text,
            (tokenId) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (promptCallback == null) return true;
                var args = new ModelPromptEventArgs(tokenId);
                return promptCallback(args);
            },
            (tokenId, response) =>
            {
                if (cancellationToken.IsCancellationRequested)
                {
                    _logger.LogDebug("ResponseCallback evt=CancellationRequested");
                    return false;
                }

                if (responseCallback == null) return true;
                var args = new ModelResponseEventArgs(tokenId, response);
                return responseCallback(args);
            },
            (isRecalculating) =>
            {
                if (cancellationToken.IsCancellationRequested) return false;
                if (recalculateCallback == null) return true;
                var args = new ModelRecalculatingEventArgs(isRecalculating);
                return recalculateCallback(args);
            },
            ref context.UnderlyingContext
        );
    }

    /// <summary>
    /// Set the number of threads to be used by the model.
    /// </summary>
    /// <param name="threadCount">The new thread count</param>
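    /// <example>
    /// For instance, assuming <c>model</c> is an existing <see cref="LLModel"/>, to use all logical processors:
    /// <code>
    /// model.SetThreadCount(Environment.ProcessorCount);
    /// </code>
    /// </example>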
    public void SetThreadCount(int threadCount)
    {
        NativeMethods.llmodel_setThreadCount(_handle, threadCount);
    }

    /// <summary>
    /// Get the number of threads used by the model.
    /// </summary>
    /// <returns>The number of threads used by the model</returns>
    public int GetThreadCount()
    {
        return NativeMethods.llmodel_threadCount(_handle);
    }

    /// <summary>
    /// Get the size of the internal state of the model.
    /// </summary>
    /// <remarks>
    /// This state data is specific to the type of model you have created.
    /// </remarks>
    /// <returns>The size in bytes of the internal state of the model</returns>
    public ulong GetStateSizeBytes()
    {
        return NativeMethods.llmodel_get_state_size(_handle);
    }

    /// <summary>
    /// Saves the internal state of the model to the specified destination address.
    /// </summary>
    /// <param name="source">A pointer to the buffer that receives the state data</param>
    /// <returns>The number of bytes copied</returns>
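    /// <example>
    /// A minimal sketch of a save/restore round trip (requires an <c>unsafe</c> context;
    /// <c>model</c> is assumed to be a loaded <see cref="LLModel"/>):
    /// <code>
    /// var buffer = new byte[model.GetStateSizeBytes()];
    /// fixed (byte* ptr = buffer)
    /// {
    ///     var written = model.SaveStateData(ptr);  // snapshot the current state
    ///     // ... generate more tokens ...
    ///     var read = model.RestoreStateData(ptr);  // roll back to the snapshot
    /// }
    /// </code>
    /// </example>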
    public unsafe ulong SaveStateData(byte* source)
    {
        return NativeMethods.llmodel_save_state_data(_handle, source);
    }

    /// <summary>
    /// Restores the internal state of the model using data from the specified address.
    /// </summary>
    /// <param name="destination">A pointer to the buffer containing the previously saved state data</param>
    /// <returns>The number of bytes read</returns>
    public unsafe ulong RestoreStateData(byte* destination)
    {
        return NativeMethods.llmodel_restore_state_data(_handle, destination);
    }

    /// <summary>
    /// Check if the model is loaded.
    /// </summary>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
    public bool IsLoaded()
    {
        return NativeMethods.llmodel_isModelLoaded(_handle);
    }

    /// <summary>
    /// Load the model from a file.
    /// </summary>
    /// <param name="modelPath">The path to the model file.</param>
    /// <returns>true if the model was loaded successfully, false otherwise.</returns>
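    /// <example>
    /// A minimal sketch; the native <c>handle</c> is assumed to have been obtained elsewhere
    /// from the underlying llmodel library, and the model path is illustrative:
    /// <code>
    /// using var model = LLModel.Create(handle, ModelType.GPTJ);
    /// if (!model.Load("/path/to/model.bin"))
    ///     throw new InvalidOperationException("Failed to load the model.");
    /// </code>
    /// </example>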
    public bool Load(string modelPath)
    {
        return NativeMethods.llmodel_loadModel(_handle, modelPath);
    }

    /// <summary>
    /// Destroy the underlying native model using the type-agnostic destructor.
    /// </summary>
    protected void Destroy()
    {
        NativeMethods.llmodel_model_destroy(_handle);
    }

    /// <summary>
    /// Destroy the underlying native model using the LLaMA-specific destructor.
    /// </summary>
    protected void DestroyLLama()
    {
        NativeMethods.llmodel_llama_destroy(_handle);
    }

    /// <summary>
    /// Destroy the underlying native model using the GPT-J-specific destructor.
    /// </summary>
    protected void DestroyGptj()
    {
        NativeMethods.llmodel_gptj_destroy(_handle);
    }

    /// <summary>
    /// Destroy the underlying native model using the MPT-specific destructor.
    /// </summary>
    protected void DestroyMtp()
    {
        NativeMethods.llmodel_mpt_destroy(_handle);
    }

    protected virtual void Dispose(bool disposing)
    {
        if (_disposed) return;

        if (disposing)
        {
            // dispose managed state
        }

        // free the native model with the destructor matching its type
        switch (_modelType)
        {
            case ModelType.LLAMA:
                DestroyLLama();
                break;
            case ModelType.GPTJ:
                DestroyGptj();
                break;
            case ModelType.MPT:
                DestroyMtp();
                break;
            default:
                Destroy();
                break;
        }

        _disposed = true;
    }

    public void Dispose()
    {
        Dispose(disposing: true);
        GC.SuppressFinalize(this);
    }
}