@ -1,18 +1,25 @@
using Gpt4All.Bindings ;
using System.Diagnostics ;
using Gpt4All.Extensions ;
using Gpt4All.Bindings ;
using Microsoft.Extensions.Logging ;
using Microsoft.Extensions.Logging.Abstractions ;
namespace Gpt4All ;
namespace Gpt4All ;
public class Gpt4All : IGpt4AllModel
public class Gpt4All : IGpt4AllModel
{
{
private readonly ILLModel _model ;
private readonly ILLModel _model ;
private readonly ILogger _logger ;
private const string ResponseErrorMessage =
"The model reported an error during token generation error={ResponseError}" ;
/// <inheritdoc/>
/// <inheritdoc/>
public IPromptFormatter ? PromptFormatter { get ; set ; }
public IPromptFormatter ? PromptFormatter { get ; set ; }
internal Gpt4All ( ILLModel model )
internal Gpt4All ( ILLModel model , ILogger ? logger = null )
{
{
_model = model ;
_model = model ;
_logger = logger ? ? NullLogger . Instance ;
PromptFormatter = new DefaultPromptFormatter ( ) ;
PromptFormatter = new DefaultPromptFormatter ( ) ;
}
}
@ -29,21 +36,36 @@ public class Gpt4All : IGpt4AllModel
return Task . Run ( ( ) = >
return Task . Run ( ( ) = >
{
{
_logger . LogInformation ( "Start prediction task" ) ;
var sw = Stopwatch . StartNew ( ) ;
var result = new TextPredictionResult ( ) ;
var result = new TextPredictionResult ( ) ;
var context = opts . ToPromptContext ( ) ;
var context = opts . ToPromptContext ( ) ;
var prompt = FormatPrompt ( text ) ;
var prompt = FormatPrompt ( text ) ;
_model . Prompt ( prompt , context , responseCallback : e = >
try
{
{
if ( e . IsError )
_model . Prompt ( prompt , context , responseCallback : e = >
{
{
result . Success = false ;
if ( e . IsError )
result . ErrorMessage = e . Response ;
{
return false ;
_logger . LogWarning ( ResponseErrorMessage , e . Response ) ;
}
result . Success = false ;
result . Append ( e . Response ) ;
result . ErrorMessage = e . Response ;
return true ;
return false ;
} , cancellationToken : cancellationToken ) ;
}
result . Append ( e . Response ) ;
return true ;
} , cancellationToken : cancellationToken ) ;
}
catch ( Exception e )
{
_logger . LogError ( e , "Prompt error" ) ;
result . Success = false ;
}
sw . Stop ( ) ;
_logger . LogInformation ( "Prediction task completed elapsed={Elapsed}s" , sw . Elapsed . TotalSeconds ) ;
return ( ITextPredictionResult ) result ;
return ( ITextPredictionResult ) result ;
} , CancellationToken . None ) ;
} , CancellationToken . None ) ;
@ -57,6 +79,9 @@ public class Gpt4All : IGpt4AllModel
_ = Task . Run ( ( ) = >
_ = Task . Run ( ( ) = >
{
{
_logger . LogInformation ( "Start streaming prediction task" ) ;
var sw = Stopwatch . StartNew ( ) ;
try
try
{
{
var context = opts . ToPromptContext ( ) ;
var context = opts . ToPromptContext ( ) ;
@ -66,6 +91,7 @@ public class Gpt4All : IGpt4AllModel
{
{
if ( e . IsError )
if ( e . IsError )
{
{
_logger . LogWarning ( ResponseErrorMessage , e . Response ) ;
result . Success = false ;
result . Success = false ;
result . ErrorMessage = e . Response ;
result . ErrorMessage = e . Response ;
return false ;
return false ;
@ -74,9 +100,16 @@ public class Gpt4All : IGpt4AllModel
return true ;
return true ;
} , cancellationToken : cancellationToken ) ;
} , cancellationToken : cancellationToken ) ;
}
}
catch ( Exception e )
{
_logger . LogError ( e , "Prompt error" ) ;
result . Success = false ;
}
finally
finally
{
{
result . Complete ( ) ;
result . Complete ( ) ;
sw . Stop ( ) ;
_logger . LogInformation ( "Prediction task completed elapsed={Elapsed}s" , sw . Elapsed . TotalSeconds ) ;
}
}
} , CancellationToken . None ) ;
} , CancellationToken . None ) ;