using LLama.Web.Async;
using LLama.Web.Common;
using LLama.Web.Models;
using Microsoft.Extensions.Options;
using System.Collections.Concurrent;
namespace LLama.Web.Services
{
/// <summary>
/// Service for handling models, weights and contexts.
/// </summary>
/// <remarks>
/// Loaded models are cached by name in <c>_modelInstances</c>. Model creation is serialized
/// through <c>_modelLock</c>; context creation and removal are serialized through
/// <c>_contextLock</c>.
/// </remarks>
public class ModelService : IModelService
{
    private readonly AsyncLock _modelLock;
    private readonly AsyncLock _contextLock;
    private readonly LLamaOptions _configuration;
    private readonly ILogger _llamaLogger;
    private readonly ConcurrentDictionary _modelInstances;

    /// <summary>
    /// Initializes a new instance of the <see cref="ModelService"/> class.
    /// </summary>
    /// <param name="configuration">The options; its <c>Value</c> is kept as the service configuration.</param>
    /// <param name="llamaLogger">The logger, handed to every model this service creates.</param>
    /// <exception cref="ArgumentNullException"><paramref name="configuration"/> is <c>null</c>.</exception>
    public ModelService(IOptions configuration, ILogger llamaLogger)
    {
        // Fail with a descriptive argument error instead of a NullReferenceException on .Value below.
        if (configuration is null)
            throw new ArgumentNullException(nameof(configuration));

        _llamaLogger = llamaLogger;
        _modelLock = new AsyncLock();
        _contextLock = new AsyncLock();
        _configuration = configuration.Value;
        _modelInstances = new ConcurrentDictionary();
    }

    /// <summary>
    /// Loads a model with the provided configuration.
    /// </summary>
    /// <remarks>
    /// If a model with the same name is already cached it is returned as-is. Otherwise a new
    /// model is created while holding the model lock. When the configured load type is
    /// <c>Single</c> or <c>PreloadSingle</c>, every currently loaded model is unloaded first.
    /// </remarks>
    /// <param name="modelOptions">The model configuration.</param>
    /// <returns>The cached or newly created model.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="modelOptions"/> is <c>null</c>.</exception>
    public async Task LoadModel(ModelOptions modelOptions)
    {
        if (modelOptions is null)
            throw new ArgumentNullException(nameof(modelOptions));

        // Fast path: no lock needed when the model is already cached.
        if (_modelInstances.TryGetValue(modelOptions.Name, out var existingModel))
            return existingModel;

        using (await _modelLock.LockAsync())
        {
            // Re-check under the lock: another caller may have loaded it while we waited.
            if (_modelInstances.TryGetValue(modelOptions.Name, out var model))
                return model;

            // In the single-model modes, unload any other models before loading this one.
            if (_configuration.ModelLoadType == ModelLoadType.Single
                || _configuration.ModelLoadType == ModelLoadType.PreloadSingle)
                await UnloadModels();

            model = new LLamaModel(modelOptions, _llamaLogger);
            _modelInstances.TryAdd(modelOptions.Name, model);
            return model;
        }
    }

    /// <summary>
    /// Loads the models that the configuration asks to preload.
    /// </summary>
    /// <remarks>
    /// Does nothing when the load type is <c>Single</c> or <c>Multiple</c>, or when no models
    /// are configured. In <c>PreloadSingle</c> mode only the first configured model is loaded;
    /// for any other load type every configured model is loaded.
    /// </remarks>
    /// <returns>A task that completes when loading has finished.</returns>
    public async Task LoadModels()
    {
        if (_configuration.ModelLoadType == ModelLoadType.Single
            || _configuration.ModelLoadType == ModelLoadType.Multiple)
            return;

        // Nothing to preload when the configuration carries no model list.
        if (_configuration.Models is null)
            return;

        foreach (var modelConfig in _configuration.Models)
        {
            await LoadModel(modelConfig);

            // Only preload the first model in PreloadSingle mode.
            if (_configuration.ModelLoadType == ModelLoadType.PreloadSingle)
                break;
        }
    }

    /// <summary>
    /// Unloads the model.
    /// </summary>
    /// <param name="modelName">Name of the model.</param>
    /// <returns>
    /// A task whose result is <c>true</c> when the model was cached and has been removed and
    /// disposed; <c>false</c> when no model with that name was cached.
    /// </returns>
    public Task UnloadModel(string modelName)
    {
        if (_modelInstances.TryRemove(modelName, out var model))
        {
            model?.Dispose();
            return Task.FromResult(true);
        }
        return Task.FromResult(false);
    }

    /// <summary>
    /// Unloads all models.
    /// </summary>
    /// <returns>A task that completes once every cached model has been unloaded.</returns>
    public async Task UnloadModels()
    {
        foreach (var modelName in _modelInstances.Keys)
        {
            await UnloadModel(modelName);
        }
    }

    /// <summary>
    /// Gets a model by name.
    /// </summary>
    /// <param name="modelName">Name of the model.</param>
    /// <returns>
    /// A task whose result is the cached model, or the default value when no model with that
    /// name is loaded.
    /// </returns>
    public Task GetModel(string modelName)
    {
        _modelInstances.TryGetValue(modelName, out var model);
        return Task.FromResult(model);
    }

    /// <summary>
    /// Gets a context from the specified model.
    /// </summary>
    /// <param name="modelName">Name of the model.</param>
    /// <param name="contextName">The contextName.</param>
    /// <returns>Whatever the model returns for <paramref name="contextName"/>.</returns>
    /// <exception cref="Exception">Model not found</exception>
    public async Task GetContext(string modelName, string contextName)
    {
        if (!_modelInstances.TryGetValue(modelName, out var model))
            throw new Exception("Model not found");

        return await model.GetContext(contextName);
    }

    /// <summary>
    /// Creates a context on the specified model.
    /// </summary>
    /// <param name="modelName">Name of the model.</param>
    /// <param name="contextName">The contextName.</param>
    /// <returns>The context created by the model.</returns>
    /// <exception cref="Exception">Model not found</exception>
    public async Task CreateContext(string modelName, string contextName)
    {
        if (!_modelInstances.TryGetValue(modelName, out var model))
            throw new Exception("Model not found");

        using (await _contextLock.LockAsync())
        {
            return await model.CreateContext(contextName);
        }
    }

    /// <summary>
    /// Removes a context from the specified model.
    /// </summary>
    /// <param name="modelName">Name of the model.</param>
    /// <param name="contextName">The contextName.</param>
    /// <returns>The result reported by the model's own remove operation.</returns>
    /// <exception cref="Exception">Model not found</exception>
    public async Task RemoveContext(string modelName, string contextName)
    {
        if (!_modelInstances.TryGetValue(modelName, out var model))
            throw new Exception("Model not found");

        using (await _contextLock.LockAsync())
        {
            return await model.RemoveContext(contextName);
        }
    }

    /// <summary>
    /// Gets the named model, loading it from configuration if necessary, and then gets or
    /// creates the named context on it.
    /// </summary>
    /// <param name="modelName">Name of the model.</param>
    /// <param name="contextName">The contextName.</param>
    /// <returns>The model together with the existing or newly created context.</returns>
    /// <exception cref="Exception">Model option '{modelName}' not found</exception>
    public async Task<(LLamaModel, LLamaContext)> GetOrCreateModelAndContext(string modelName, string contextName)
    {
        if (!_modelInstances.TryGetValue(modelName, out var model))
        {
            // Get model configuration; a missing model list is treated the same as a missing entry.
            var modelConfig = _configuration.Models?.FirstOrDefault(x => x.Name == modelName);
            if (modelConfig is null)
                throw new Exception($"Model option '{modelName}' not found");

            // Load Model
            model = await LoadModel(modelConfig);
        }

        // Fast path: the context already exists, no lock required.
        var context = await model.GetContext(contextName);
        if (context is null)
        {
            // Context creation elsewhere in this service is serialized through _contextLock;
            // take the same lock and re-check so concurrent callers asking for the same new
            // context do not both end up calling CreateContext.
            using (await _contextLock.LockAsync())
            {
                context = await model.GetContext(contextName) ?? await model.CreateContext(contextName);
            }
        }

        return (model, context);
    }
}
}