| @@ -70,6 +70,8 @@ namespace LLama | |||||
| /// </summary> | /// </summary> | ||||
| protected float? MirostatMu { get; set; } | protected float? MirostatMu { get; set; } | ||||
| private StreamingTokenDecoder _decoder; | |||||
| /// <summary> | /// <summary> | ||||
| /// | /// | ||||
| /// </summary> | /// </summary> | ||||
| @@ -83,6 +85,7 @@ namespace LLama | |||||
| _consumedTokensCount = 0; | _consumedTokensCount = 0; | ||||
| _n_session_consumed = 0; | _n_session_consumed = 0; | ||||
| _last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize).FillWith(0); | _last_n_tokens = new FixedSizeQueue<llama_token>(Context.ContextSize).FillWith(0); | ||||
| _decoder = new StreamingTokenDecoder(context); | |||||
| } | } | ||||
| /// <summary> | /// <summary> | ||||
| @@ -294,7 +297,10 @@ namespace LLama | |||||
| await InferInternal(inferenceParams, args); | await InferInternal(inferenceParams, args); | ||||
| if (args.ReturnValue) | if (args.ReturnValue) | ||||
| yield return Context.DeTokenize(_embeds); | |||||
| { | |||||
| _decoder.AddRange(_embeds); | |||||
| yield return _decoder.Read(); | |||||
| } | |||||
| var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args); | var (breakGeneration, extraOutputs) = await PostProcess(inferenceParams, args); | ||||
| if (extraOutputs is { Count: > 0 }) | if (extraOutputs is { Count: > 0 }) | ||||