You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

LLamaContextExtensions.cs 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
  1. using System;
  2. using System.Buffers.Binary;
  3. using System.IO;
  4. using System.IO.MemoryMappedFiles;
  5. using LLama.Native;
  6. namespace LLama.Batched;
  /// <summary>
  /// Extension methods for saving/loading the state of a single sequence of a <see cref="LLamaContext"/>
  /// to/from a file, together with a caller-supplied blob of extra "header" bytes.
  /// </summary>
  /// <remarks>
  /// File layout, as written by <c>SaveState</c>:
  /// 4 bytes magic (big-endian), 4 bytes header length (big-endian), the header bytes, then the native state data.
  /// </remarks>
  internal static class LLamaContextExtensions
  {
      /// <summary>
      /// Magic number written (big-endian) as the first 4 bytes of every state file and verified on load.
      /// </summary>
      private const uint FileHeaderMagic = 3430400180;

      /// <summary>
      /// Size in bytes of the fixed prefix: the 4-byte magic followed by the 4-byte length of the extra header data.
      /// </summary>
      private const int FixedPrefixSize = sizeof(uint) + sizeof(uint);

      /// <summary>
      /// Save the state of a particular sequence to specified path. Also save some extra data which will be returned when loading.
      /// Data saved with this method <b>must</b> be loaded with <see cref="LoadState(LLamaContext, string, LLamaSeqId, out byte[])"/>
      /// </summary>
      /// <param name="context">Context whose native handle provides the sequence state</param>
      /// <param name="filename">Path of the file to write. Any existing file at this path is deleted first</param>
      /// <param name="sequence">The sequence to save</param>
      /// <param name="header">Extra bytes stored in front of the state data, returned again when loading</param>
      /// <exception cref="OverflowException">The total size of the file to write does not fit into a <see cref="long"/></exception>
      internal static void SaveState(this LLamaContext context, string filename, LLamaSeqId sequence, ReadOnlySpan<byte> header)
      {
          // Delete that file before overwriting it
          if (File.Exists(filename))
          {
              File.Delete(filename);
          }

          // Estimate size of state to write to disk, this is always equal to or greater than the actual size
          var estimatedStateSize = checked((long)context.NativeHandle.GetStateSize(sequence));

          // Space for the fixed 8 byte prefix (magic + header length) plus the "extra" header bytes.
          // Computed as a long so that a very large header cannot overflow 32 bit arithmetic.
          var prefixSize = (long)FixedPrefixSize + header.Length;

          // Total capacity to map: prefix followed by the (over-)estimated state
          var totalFileSize = checked(prefixSize + estimatedStateSize);

          // Map the file and write the bytes directly to it. The mapping must be disposed before the
          // file can be truncated below, hence the explicit scope.
          long writtenBytes = 0;
          using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Create, null, totalFileSize))
          using (var view = file.CreateViewAccessor(0, totalFileSize))
          {
              unsafe
              {
                  byte* ptr = null;
                  view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                  try
                  {
                      // Write prefix data: magic, header length, header bytes
                      BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, sizeof(uint)), FileHeaderMagic);
                      writtenBytes += sizeof(uint);

                      BinaryPrimitives.WriteUInt32BigEndian(new Span<byte>(ptr + writtenBytes, sizeof(uint)), (uint)header.Length);
                      writtenBytes += sizeof(uint);

                      header.CopyTo(new Span<byte>(ptr + writtenBytes, header.Length));
                      writtenBytes += header.Length;

                      // Write state data directly after the prefix; the return value is added to the running total
                      writtenBytes += (long)context.NativeHandle.GetState(ptr + writtenBytes, (ulong)estimatedStateSize, sequence);
                  }
                  finally
                  {
                      view.SafeMemoryMappedViewHandle.ReleasePointer();
                  }
              }
          }

          // Truncate the file to the actual size of data that was written
          using (var fileStream = new FileStream(filename, FileMode.Open))
          {
              fileStream.SetLength(writtenBytes);
          }
      }

      /// <summary>
      /// Load the state from the specified path into a particular sequence. Also reading header data. Must only be used with
      /// data previously saved with <see cref="SaveState(LLamaContext, string, LLamaSeqId, ReadOnlySpan{byte})"/>
      /// </summary>
      /// <param name="context">Context whose native handle receives the sequence state</param>
      /// <param name="filename">Path of the file to read</param>
      /// <param name="sequence">The sequence to load the state into</param>
      /// <param name="header">The extra bytes that were stored in front of the state data when saving</param>
      /// <exception cref="InvalidOperationException">
      /// The file is too small to contain the prefix, does not start with the expected magic number,
      /// or declares a header length larger than the remaining file contents
      /// </exception>
      internal static void LoadState(this LLamaContext context, string filename, LLamaSeqId sequence, out byte[] header)
      {
          // Use the real file length for bounds checks: a view's capacity may be rounded up to whole pages,
          // so it cannot be used to detect a truncated file.
          var fileLength = new FileInfo(filename).Length;
          if (fileLength < FixedPrefixSize)
          {
              throw new InvalidOperationException("Invalid file header: file is too small");
          }

          // Map state file into memory and pass that pointer directly to `llama_set_state_data` to load from
          using (var file = MemoryMappedFile.CreateFromFile(filename, FileMode.Open, null))
          using (var view = file.CreateViewAccessor())
          {
              unsafe
              {
                  byte* ptr = null;
                  view.SafeMemoryMappedViewHandle.AcquirePointer(ref ptr);
                  try
                  {
                      var readBytes = 0;

                      // Read and verify the magic number
                      var magic = BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, sizeof(uint)));
                      readBytes += sizeof(uint);
                      if (magic != FileHeaderMagic)
                      {
                          throw new InvalidOperationException("Invalid file header");
                      }

                      // Read the length of the extra header data and make sure it actually fits inside the file
                      // before touching the memory it claims to occupy.
                      var declaredHeaderLength = BinaryPrimitives.ReadUInt32BigEndian(new ReadOnlySpan<byte>(ptr + readBytes, sizeof(uint)));
                      readBytes += sizeof(uint);
                      if (declaredHeaderLength > fileLength - FixedPrefixSize)
                      {
                          throw new InvalidOperationException("Invalid file header: header length exceeds file size");
                      }

                      var headerLength = checked((int)declaredHeaderLength);

                      // Copy the extra header data out for the caller
                      header = new byte[headerLength];
                      new ReadOnlySpan<byte>(ptr + readBytes, headerLength).CopyTo(header);
                      readBytes += headerLength;

                      // Everything after the prefix is the state data
                      context.NativeHandle.SetState(ptr + readBytes, sequence);
                  }
                  finally
                  {
                      view.SafeMemoryMappedViewHandle.ReleasePointer();
                  }
              }
          }
      }
  }