using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Text;
namespace LLama
{
using llama_token = Int32;
///
/// Cache for a llama.cpp model.
///
public class LLamaCache
{
    // Maps a token-sequence key to its node in _cacheList so lookup and LRU
    // reordering are O(1) once the stored key instance is known.
    // NOTE(review): array keys use reference equality. That is safe here only
    // because every read path goes through FindLongestPrefixKey, which returns
    // the exact key instance stored in this dictionary — confirm no caller
    // relies on structural (element-wise) key equality.
    private readonly Dictionary<llama_token[], LinkedListNode<KeyValuePair<llama_token[], LLamaState>>> _cacheState;

    // LRU order: least-recently-used entries at the front, most recent at the back.
    private readonly LinkedList<KeyValuePair<llama_token[], LLamaState>> _cacheList;

    // Approximate capacity bound in bytes; entries are evicted while exceeded.
    private readonly int _capacity;

    /// <summary>
    /// Total size (bytes) of all cached states, summed from each state's Size.
    /// </summary>
    public int CacheSize
    {
        get
        {
            return _cacheState.Values.Sum(node => node.Value.Value.Size);
        }
    }

    /// <summary>
    /// Creates an empty cache.
    /// </summary>
    /// <param name="capacity">
    /// The max capacity (bytes). Defaults to <see cref="int.MaxValue"/> (~2 GiB).
    /// The previous default of <c>2 &lt;&lt; 30</c> silently overflowed
    /// <see cref="Int32"/> to a negative value, which made every insertion
    /// immediately evict the whole cache.
    /// </param>
    public LLamaCache(int capacity = int.MaxValue)
    {
        _cacheState = new();
        _cacheList = new();
        _capacity = capacity;
    }

    /// <summary>
    /// Gets the cached state whose stored key shares the longest token prefix
    /// with <paramref name="key"/> (marking it most recently used), or stores
    /// <c>value</c> under <paramref name="key"/>, evicting least-recently-used
    /// entries while the total size exceeds the capacity.
    /// </summary>
    /// <exception cref="KeyNotFoundException">
    /// Thrown by the getter when no cached key shares a prefix with <paramref name="key"/>.
    /// </exception>
    public LLamaState this[llama_token[] key]
    {
        get
        {
            var prefixKey = FindLongestPrefixKey(key);
            if (prefixKey is null)
            {
                throw new KeyNotFoundException();
            }
            var node = _cacheState[prefixKey];
            // Touch the entry so it becomes the most recently used.
            MoveNodeToEnd(prefixKey);
            return node.Value.Value;
        }
        set
        {
            // If this exact key instance is already cached, unlink its old node
            // first; otherwise the list accumulates orphaned entries that can
            // later be evicted in place of live ones.
            if (_cacheState.TryGetValue(key, out var existing))
            {
                _cacheList.Remove(existing);
            }
            var node = _cacheList.AddLast(new KeyValuePair<llama_token[], LLamaState>(key, value));
            _cacheState[key] = node;
            // Evict from the LRU front until within capacity. This may evict
            // the entry just added if it alone exceeds the capacity.
            while (CacheSize > _capacity && _cacheList.Count > 0)
            {
                var oldest = _cacheList.First;
                _cacheState.Remove(oldest.Value.Key);
                _cacheList.RemoveFirst();
            }
        }
    }

    /// <summary>
    /// Whether any cached key shares a non-empty token prefix with <paramref name="key"/>.
    /// </summary>
    public bool Contains(llama_token[] key)
    {
        return FindLongestPrefixKey(key) is not null;
    }

    /// <summary>
    /// Returns the stored key sharing the longest non-empty token prefix with
    /// <paramref name="key"/>, or null when no stored key shares any prefix.
    /// </summary>
    private llama_token[]? FindLongestPrefixKey(llama_token[] key)
    {
        var bestLen = 0;
        llama_token[]? bestKey = null;
        foreach (var candidate in _cacheState.Keys)
        {
            var prefixLen = LLamaModelV1.LongestTokenPrefix(candidate, key);
            if (prefixLen > bestLen)
            {
                bestLen = prefixLen;
                bestKey = candidate;
            }
        }
        return bestKey;
    }

    /// <summary>
    /// Moves the entry for <paramref name="key"/> to the most-recently-used end
    /// of the LRU list. No-op when the key is not cached.
    /// </summary>
    private void MoveNodeToEnd(llama_token[] key)
    {
        if (!_cacheState.TryGetValue(key, out var node))
        {
            return;
        }
        // Re-link the existing node at the back instead of removing the
        // dictionary entry and allocating a replacement node; the dictionary
        // still maps key -> node and the node's value is unchanged.
        _cacheList.Remove(node);
        _cacheList.AddLast(node);
    }
}
}