using LLama.Web.Async;
using LLama.Web.Common;
using LLama.Web.Models;
using Microsoft.Extensions.Options;
using System.Collections.Concurrent;

namespace LLama.Web.Services
{
    /// <summary>
    /// Service for loading and unloading models and for managing the contexts created on them.
    /// </summary>
    public class ModelService : IModelService
    {
        // Serializes model loading so that a given model is only constructed once.
        private readonly AsyncLock _modelLock;

        // Serializes context creation/removal performed through this service.
        private readonly AsyncLock _contextLock;

        private readonly LLamaOptions _configuration;

        // NOTE(review): declared as the non-generic ILogger. The default Microsoft DI container
        // only registers ILogger<T>; confirm this is not meant to be ILogger<ModelService>.
        private readonly ILogger _llamaLogger;

        // Loaded models keyed by model name.
        private readonly ConcurrentDictionary<string, LLamaModel> _modelInstances;

        /// <summary>
        /// Initializes a new instance of the <see cref="ModelService"/> class.
        /// </summary>
        /// <param name="configuration">The options; <c>Value</c> is read once and kept.</param>
        /// <param name="llamaLogger">The logger handed to every model this service creates.</param>
        public ModelService(IOptions<LLamaOptions> configuration, ILogger llamaLogger)
        {
            _llamaLogger = llamaLogger;
            _modelLock = new AsyncLock();
            _contextLock = new AsyncLock();
            _configuration = configuration.Value;
            _modelInstances = new ConcurrentDictionary<string, LLamaModel>();
        }

        /// <summary>
        /// Loads a model with the provided configuration, or returns the instance that is
        /// already registered under the same name.
        /// </summary>
        /// <param name="modelOptions">The model configuration.</param>
        /// <returns>The loaded model.</returns>
        public async Task<LLamaModel> LoadModel(ModelOptions modelOptions)
        {
            // Fast path: the model is already registered, no lock required.
            if (_modelInstances.TryGetValue(modelOptions.Name, out var existingModel))
            {
                return existingModel;
            }

            using (await _modelLock.LockAsync())
            {
                // Re-check under the lock: another caller may have loaded it while we waited.
                if (_modelInstances.TryGetValue(modelOptions.Name, out var model))
                {
                    return model;
                }

                // In the single-model modes, unload every other model before loading this one.
                var isSingleMode = _configuration.ModelLoadType == ModelLoadType.Single
                    || _configuration.ModelLoadType == ModelLoadType.PreloadSingle;
                if (isSingleMode)
                {
                    await UnloadModels();
                }

                model = new LLamaModel(modelOptions, _llamaLogger);
                _modelInstances.TryAdd(modelOptions.Name, model);
                return model;
            }
        }

        /// <summary>
        /// Preloads the configured models. Does nothing when the load type is
        /// <c>Single</c> or <c>Multiple</c>.
        /// </summary>
        public async Task LoadModels()
        {
            if (_configuration.ModelLoadType == ModelLoadType.Single
                || _configuration.ModelLoadType == ModelLoadType.Multiple)
            {
                return;
            }

            foreach (var modelConfig in _configuration.Models)
            {
                await LoadModel(modelConfig);

                // Only preload the first model when in PreloadSingle mode.
                if (_configuration.ModelLoadType == ModelLoadType.PreloadSingle)
                {
                    break;
                }
            }
        }

        /// <summary>
        /// Unloads the model: removes it from the registry and disposes it.
        /// </summary>
        /// <param name="modelName">Name of the model.</param>
        /// <returns><c>true</c> if a model with that name was registered and removed; otherwise <c>false</c>.</returns>
        public Task<bool> UnloadModel(string modelName)
        {
            if (_modelInstances.TryRemove(modelName, out var model))
            {
                model?.Dispose();
                return Task.FromResult(true);
            }

            return Task.FromResult(false);
        }

        /// <summary>
        /// Unloads all models.
        /// </summary>
        public async Task UnloadModels()
        {
            foreach (var modelName in _modelInstances.Keys)
            {
                await UnloadModel(modelName);
            }
        }

        /// <summary>
        /// Gets a model by name.
        /// </summary>
        /// <param name="modelName">Name of the model.</param>
        /// <returns>The registered model, or <c>null</c> when no model with that name is loaded.</returns>
        public Task<LLamaModel> GetModel(string modelName)
        {
            _modelInstances.TryGetValue(modelName, out var model);
            return Task.FromResult(model);
        }

        /// <summary>
        /// Gets a context from the specified model.
        /// </summary>
        /// <param name="modelName">Name of the model.</param>
        /// <param name="contextName">The contextName.</param>
        /// <returns>Whatever the model's <c>GetContext</c> yields for <paramref name="contextName"/>.</returns>
        /// <exception cref="System.Exception">Model not found</exception>
        public async Task<LLamaContext> GetContext(string modelName, string contextName)
        {
            if (!_modelInstances.TryGetValue(modelName, out var model))
            {
                throw new Exception("Model not found");
            }

            return await model.GetContext(contextName);
        }

        /// <summary>
        /// Creates a context on the specified model while holding the context lock.
        /// </summary>
        /// <param name="modelName">Name of the model.</param>
        /// <param name="contextName">The contextName.</param>
        /// <returns>The context returned by the model's <c>CreateContext</c>.</returns>
        /// <exception cref="System.Exception">Model not found</exception>
        public async Task<LLamaContext> CreateContext(string modelName, string contextName)
        {
            if (!_modelInstances.TryGetValue(modelName, out var model))
            {
                throw new Exception("Model not found");
            }

            using (await _contextLock.LockAsync())
            {
                return await model.CreateContext(contextName);
            }
        }

        /// <summary>
        /// Removes a context from the specified model while holding the context lock.
        /// </summary>
        /// <param name="modelName">Name of the model.</param>
        /// <param name="contextName">The contextName.</param>
        /// <returns>The result of the model's <c>RemoveContext</c>.</returns>
        /// <exception cref="System.Exception">Model not found</exception>
        // NOTE(review): result type assumed to be bool (mirrors UnloadModel) - confirm against
        // IModelService and LLamaModel.RemoveContext.
        public async Task<bool> RemoveContext(string modelName, string contextName)
        {
            if (!_modelInstances.TryGetValue(modelName, out var model))
            {
                throw new Exception("Model not found");
            }

            using (await _contextLock.LockAsync())
            {
                return await model.RemoveContext(contextName);
            }
        }

        /// <summary>
        /// Returns the named model and context, loading the model from configuration and
        /// creating the context when either does not exist yet.
        /// </summary>
        /// <param name="modelName">Name of the model.</param>
        /// <param name="contextName">The contextName.</param>
        /// <returns>The model together with the existing or newly created context.</returns>
        /// <exception cref="System.Exception">Model option '{modelName}' not found</exception>
        public async Task<(LLamaModel, LLamaContext)> GetOrCreateModelAndContext(string modelName, string contextName)
        {
            if (!_modelInstances.TryGetValue(modelName, out var model))
            {
                // Not loaded yet: look up its configuration and load it.
                var modelConfig = _configuration.Models.FirstOrDefault(x => x.Name == modelName);
                if (modelConfig is null)
                {
                    throw new Exception($"Model option '{modelName}' not found");
                }

                model = await LoadModel(modelConfig);
            }

            return (model, await GetOrCreateContext(model, contextName));
        }

        /// <summary>
        /// Returns the named context of <paramref name="model"/>, creating it under the
        /// context lock when it does not exist.
        /// </summary>
        /// <param name="model">The model that owns the context.</param>
        /// <param name="contextName">The contextName.</param>
        /// <returns>The existing or newly created context.</returns>
        private async Task<LLamaContext> GetOrCreateContext(LLamaModel model, string contextName)
        {
            // Fast path: the context already exists, no lock required.
            var context = await model.GetContext(contextName);
            if (context is null)
            {
                using (await _contextLock.LockAsync())
                {
                    // Re-check under the lock so that concurrent callers asking for the same
                    // context do not each try to create it.
                    context = await model.GetContext(contextName)
                        ?? await model.CreateContext(contextName);
                }
            }

            return context;
        }
    }
}