diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/CacheController.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/CacheController.cs index 580b6a4..dc3bb0d 100644 --- a/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/CacheController.cs +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/CacheController.cs @@ -34,27 +34,22 @@ public class CacheController : ControllerBase _requestQueue.ActivateRequest(requestId); Response.ContentType = "application/octet-stream"; - var memoryStream = new MemoryStream(); - var streamWriter = new BinaryWriter(memoryStream); long requestSize = 0; + var streamList = new List(); foreach (var file in request.FileIds) { var fs = await _cachedFileProvider.GetAndDownloadFileStream(file); if (fs == null) continue; - streamWriter.Write(Encoding.ASCII.GetBytes("#" + file + ":" + fs.Length.ToString(CultureInfo.InvariantCulture) + "#")); - byte[] buffer = new byte[fs.Length]; - _ = await fs.ReadAsync(buffer, HttpContext.RequestAborted); - streamWriter.Write(buffer); + var headerBytes = Encoding.ASCII.GetBytes("#" + file + ":" + fs.Length.ToString(CultureInfo.InvariantCulture) + "#"); + streamList.Add(new MemoryStream(headerBytes)); + streamList.Add(fs); requestSize += fs.Length; } - streamWriter.Flush(); - memoryStream.Position = 0; - _fileStatisticsService.LogRequest(requestSize); - return _requestFileStreamResultFactory.Create(requestId, memoryStream); + return _requestFileStreamResultFactory.Create(requestId, new ConcatenatedStreamReader(streamList)); } } diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/RequestFileStreamResultFactory.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/RequestFileStreamResultFactory.cs index e1e05ac..43a952d 100644 --- a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/RequestFileStreamResultFactory.cs +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/RequestFileStreamResultFactory.cs @@ -17,9 +17,9 @@ public class RequestFileStreamResultFactory _configurationService = configurationService; } - public RequestFileStreamResult Create(Guid requestId, MemoryStream ms) + public RequestFileStreamResult Create(Guid requestId, Stream stream) { return new RequestFileStreamResult(requestId, _requestQueueService, - _metrics, ms, "application/octet-stream"); + _metrics, stream, "application/octet-stream"); } } \ No newline at end of file diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs new file mode 100644 index 0000000..d85ef4f --- /dev/null +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs @@ -0,0 +1,109 @@ +namespace MareSynchronosStaticFilesServer.Utils; + +public class CountedStream : Stream +{ + private readonly Stream _stream; + public ulong BytesRead { get; private set; } + public ulong BytesWritten { get; private set; } + + public CountedStream(Stream underlyingStream) + { + _stream = underlyingStream; + } + + public override bool CanRead => _stream.CanRead; + + public override bool CanSeek => _stream.CanSeek; + + public override bool CanWrite => _stream.CanWrite; + + public override long Length => _stream.Length; + + public override long Position { get => _stream.Position; set => _stream.Position = value; } + + public override void Flush() + { + _stream.Flush(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + int n = _stream.Read(buffer, offset, count); + BytesRead += (ulong)n; + return n; + } + + public override long Seek(long offset, SeekOrigin origin) + { + return _stream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _stream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + BytesWritten += (ulong)count; + _stream.Write(buffer, offset, count); + } +} + +public class ConcatenatedStreamReader : Stream +{ + private IEnumerable _streams; + private IEnumerator _iter; + private bool _finished; + + public ConcatenatedStreamReader(IEnumerable streams) + { + _streams = streams; + _iter = streams.GetEnumerator(); + _finished = !_iter.MoveNext(); + } + + public override bool CanRead => true; + + public override bool CanSeek => false; + + public override bool CanWrite => false; + + public override long Length => throw new NotSupportedException(); + + public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } + + public override void Flush() + { + } + + public override int Read(byte[] buffer, int offset, int count) + { + int n = 0; + + while (n == 0 && !_finished) + { + n = _iter.Current.Read(buffer, offset, count); + + if (n == 0) + _finished = !_iter.MoveNext(); + } + + return n; + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + throw new NotSupportedException(); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } +}