From 676c5316f61dde7f2ea51264ca543dcd7cb2bc71 Mon Sep 17 00:00:00 2001 From: Loporrit <141286461+loporrit@users.noreply.github.com> Date: Thu, 24 Aug 2023 18:31:11 +0000 Subject: [PATCH] Avoid buffering file download bundles in to memory --- .../Controllers/CacheController.cs | 15 +-- .../Utils/RequestFileStreamResultFactory.cs | 4 +- .../Utils/StreamUtils.cs | 109 ++++++++++++++++++ 3 files changed, 116 insertions(+), 12 deletions(-) create mode 100644 MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/CacheController.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/CacheController.cs index 580b6a4..dc3bb0d 100644 --- a/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/CacheController.cs +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/CacheController.cs @@ -34,27 +34,22 @@ public class CacheController : ControllerBase _requestQueue.ActivateRequest(requestId); Response.ContentType = "application/octet-stream"; - var memoryStream = new MemoryStream(); - var streamWriter = new BinaryWriter(memoryStream); long requestSize = 0; + var streamList = new List(); foreach (var file in request.FileIds) { var fs = await _cachedFileProvider.GetAndDownloadFileStream(file); if (fs == null) continue; - streamWriter.Write(Encoding.ASCII.GetBytes("#" + file + ":" + fs.Length.ToString(CultureInfo.InvariantCulture) + "#")); - byte[] buffer = new byte[fs.Length]; - _ = await fs.ReadAsync(buffer, HttpContext.RequestAborted); - streamWriter.Write(buffer); + var headerBytes = Encoding.ASCII.GetBytes("#" + file + ":" + fs.Length.ToString(CultureInfo.InvariantCulture) + "#"); + streamList.Add(new MemoryStream(headerBytes)); + streamList.Add(fs); requestSize += fs.Length; } - streamWriter.Flush(); - memoryStream.Position = 0; - _fileStatisticsService.LogRequest(requestSize); - return _requestFileStreamResultFactory.Create(requestId, memoryStream); + return _requestFileStreamResultFactory.Create(requestId, new ConcatenatedStreamReader(streamList)); } } diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/RequestFileStreamResultFactory.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/RequestFileStreamResultFactory.cs index e1e05ac..43a952d 100644 --- a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/RequestFileStreamResultFactory.cs +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/RequestFileStreamResultFactory.cs @@ -17,9 +17,9 @@ public class RequestFileStreamResultFactory _configurationService = configurationService; } - public RequestFileStreamResult Create(Guid requestId, MemoryStream ms) + public RequestFileStreamResult Create(Guid requestId, Stream stream) { return new RequestFileStreamResult(requestId, _requestQueueService, - _metrics, ms, "application/octet-stream"); + _metrics, stream, "application/octet-stream"); } } \ No newline at end of file diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs new file mode 100644 index 0000000..d85ef4f --- /dev/null +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs @@ -0,0 +1,109 @@ +namespace MareSynchronosStaticFilesServer.Utils; + +public class CountedStream : Stream +{ + private readonly Stream _stream; + public ulong BytesRead { get; private set; } + public ulong BytesWritten { get; private set; } + + public CountedStream(Stream underlyingStream) + { + _stream = underlyingStream; + } + + public override bool CanRead => _stream.CanRead; + + public override bool CanSeek => _stream.CanSeek; + + public override bool CanWrite => _stream.CanWrite; + + public override long Length => _stream.Length; + + public override long Position { get => _stream.Position; set => _stream.Position = value; } + + public override void Flush() + { + _stream.Flush(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + int n = _stream.Read(buffer, offset, count); + BytesRead += (ulong)n; + return n; + } + + public override long Seek(long offset, SeekOrigin origin) + { + return _stream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _stream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + BytesWritten += (ulong)count; + _stream.Write(buffer, offset, count); + } +} + +public class ConcatenatedStreamReader : Stream +{ + private IEnumerable _streams; + private IEnumerator _iter; + private bool _finished; + + public ConcatenatedStreamReader(IEnumerable streams) + { + _streams = streams; + _iter = streams.GetEnumerator(); + _finished = !_iter.MoveNext(); + } + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => false; + + public override long Length => throw new NotSupportedException(); + + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + { + int n = 0; + + while (n == 0 && !_finished) + { + n = _iter.Current.Read(buffer, offset, count); + + if (n == 0) + _finished = !_iter.MoveNext(); + } + + return n; + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } +}