diff --git a/MareSynchronosServer/MareSynchronosServer/MareSynchronosServer.csproj b/MareSynchronosServer/MareSynchronosServer/MareSynchronosServer.csproj index 0271300..9d1cdbb 100644 --- a/MareSynchronosServer/MareSynchronosServer/MareSynchronosServer.csproj +++ b/MareSynchronosServer/MareSynchronosServer/MareSynchronosServer.csproj @@ -27,7 +27,6 @@ all runtime; build; native; contentfiles; analyzers; buildtransitive - all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/ServerFilesController.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/ServerFilesController.cs index fb09896..5f2630b 100644 --- a/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/ServerFilesController.cs +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Controllers/ServerFilesController.cs @@ -1,4 +1,4 @@ -using LZ4; +using K4os.Compression.LZ4.Streams; using MareSynchronos.API.Dto.Files; using MareSynchronos.API.Routes; using MareSynchronos.API.SignalR; @@ -183,115 +183,68 @@ public class ServerFilesController : ControllerBase return Ok(); } - // copy the request body to memory - using var compressedFileStream = new MemoryStream(); - await Request.Body.CopyToAsync(compressedFileStream, requestAborted).ConfigureAwait(false); - - // decompress and copy the decompressed stream to memory - var data = LZ4Codec.Unwrap(compressedFileStream.ToArray()); - - // reset streams - compressedFileStream.Seek(0, SeekOrigin.Begin); - - // compute hash to verify - var hashString = BitConverter.ToString(SHA1.HashData(data)) - .Replace("-", "", StringComparison.Ordinal).ToUpperInvariant(); - if (!string.Equals(hashString, hash, StringComparison.Ordinal)) - throw new InvalidOperationException($"Hash does not match file, computed: {hashString}, expected: {hash}"); - - // save file var path = FilePathUtil.GetFilePath(_basePath, hash); - using var fileStream = new FileStream(path, FileMode.Create); - await compressedFileStream.CopyToAsync(fileStream).ConfigureAwait(false); + var tmpPath = path + ".tmp"; + long compressedSize = -1; - // update on db - await _mareDbContext.Files.AddAsync(new FileCache() + try { - Hash = hash, - UploadDate = DateTime.UtcNow, - UploaderUID = MareUser, - Size = compressedFileStream.Length, - Uploaded = true - }).ConfigureAwait(false); - await _mareDbContext.SaveChangesAsync().ConfigureAwait(false); + // Write incoming file to a temporary file while also hashing the decompressed content - _metricsClient.IncGauge(MetricsAPI.GaugeFilesTotal, 1); - _metricsClient.IncGauge(MetricsAPI.GaugeFilesTotalSize, compressedFileStream.Length); + // Stream flow diagram: + // Request.Body ==> (Tee) ==> FileStream + // ==> CountedStream ==> LZ4DecoderStream ==> HashingStream ==> Stream.Null - _fileUploadLocks.TryRemove(hash, out _); + // Reading via TeeStream causes the request body to be copied to tmpPath + using var tmpFileStream = new FileStream(tmpPath, FileMode.Create); + using var teeStream = new TeeStream(Request.Body, tmpFileStream); + teeStream.DisposeUnderlying = false; + // Read via CountedStream to count the number of compressed bytes + using var countStream = new CountedStream(teeStream); + countStream.DisposeUnderlying = false; - return Ok(); - } - catch (Exception e) - { - _logger.LogError(e, "Error during file upload"); - return BadRequest(); - } - finally - { - fileLock.Release(); - } - } + // The decompressed file content is read through LZ4DecoderStream, and written out to HashingStream + using var decStream = LZ4Stream.Decode(countStream, extraMemory: 0, leaveOpen: true); + // HashingStream simply hashes the decompressed bytes without writing them anywhere + using var hashStream = new HashingStream(Stream.Null, SHA1.Create()); + hashStream.DisposeUnderlying = false; - [HttpPost(MareFiles.ServerFiles_UploadRaw + "/{hash}")] - [RequestSizeLimit(200 * 1024 * 1024)] - public async Task UploadFileRaw(string hash, CancellationToken requestAborted) - { - _logger.LogInformation("{user} uploading raw file {file}", MareUser, hash); - hash = hash.ToUpperInvariant(); - var existingFile = await _mareDbContext.Files.SingleOrDefaultAsync(f => f.Hash == hash); - if (existingFile != null) return Ok(); + await decStream.CopyToAsync(hashStream, requestAborted).ConfigureAwait(false); + decStream.Close(); - SemaphoreSlim fileLock; - lock (_fileUploadLocks) - { - if (!_fileUploadLocks.TryGetValue(hash, out fileLock)) - _fileUploadLocks[hash] = fileLock = new SemaphoreSlim(1); - } + var hashString = BitConverter.ToString(hashStream.Finish()) + .Replace("-", "", StringComparison.Ordinal).ToUpperInvariant(); + if (!string.Equals(hashString, hash, StringComparison.Ordinal)) + throw new InvalidOperationException($"Hash does not match file, computed: {hashString}, expected: {hash}"); - await fileLock.WaitAsync(requestAborted).ConfigureAwait(false); + compressedSize = countStream.BytesRead; - try - { - var existingFileCheck2 = await _mareDbContext.Files.SingleOrDefaultAsync(f => f.Hash == hash); - if (existingFileCheck2 != null) + // File content is verified -- move it to its final location + System.IO.File.Move(tmpPath, path, true); + } + catch { - return Ok(); + try + { + System.IO.File.Delete(tmpPath); + } + catch { } + throw; } - // copy the request body to memory - using var rawFileStream = new MemoryStream(); - await Request.Body.CopyToAsync(rawFileStream, requestAborted).ConfigureAwait(false); - - // reset streams - rawFileStream.Seek(0, SeekOrigin.Begin); - - // compute hash to verify - var hashString = BitConverter.ToString(SHA1.HashData(rawFileStream.ToArray())) - .Replace("-", "", StringComparison.Ordinal).ToUpperInvariant(); - if (!string.Equals(hashString, hash, StringComparison.Ordinal)) - throw new InvalidOperationException($"Hash does not match file, computed: {hashString}, expected: {hash}"); - - // save file - var path = FilePathUtil.GetFilePath(_basePath, hash); - using var fileStream = new FileStream(path, FileMode.Create); - var lz4 = LZ4Codec.WrapHC(rawFileStream.ToArray(), 0, (int)rawFileStream.Length); - using var compressedStream = new MemoryStream(lz4); - await compressedStream.CopyToAsync(fileStream).ConfigureAwait(false); - // update on db await _mareDbContext.Files.AddAsync(new FileCache() { Hash = hash, UploadDate = DateTime.UtcNow, UploaderUID = MareUser, - Size = compressedStream.Length, + Size = compressedSize, Uploaded = true }).ConfigureAwait(false); await _mareDbContext.SaveChangesAsync().ConfigureAwait(false); _metricsClient.IncGauge(MetricsAPI.GaugeFilesTotal, 1); - _metricsClient.IncGauge(MetricsAPI.GaugeFilesTotalSize, rawFileStream.Length); + _metricsClient.IncGauge(MetricsAPI.GaugeFilesTotalSize, compressedSize); _fileUploadLocks.TryRemove(hash, out _); diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/MareSynchronosStaticFilesServer.csproj b/MareSynchronosServer/MareSynchronosStaticFilesServer/MareSynchronosStaticFilesServer.csproj index ab914c0..94e5a0c 100644 --- a/MareSynchronosServer/MareSynchronosStaticFilesServer/MareSynchronosStaticFilesServer.csproj +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/MareSynchronosStaticFilesServer.csproj @@ -18,7 +18,7 @@ - + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/ConcatenatedStreamReader.cs similarity index 55% rename from MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs rename to MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/ConcatenatedStreamReader.cs index d85ef4f..3402444 100644 --- a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/StreamUtils.cs +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/ConcatenatedStreamReader.cs @@ -1,60 +1,12 @@ namespace MareSynchronosStaticFilesServer.Utils; -public class CountedStream : Stream -{ - private readonly Stream _stream; - public ulong BytesRead { get; private set; } - public ulong BytesWritten { get; private set; } - - public CountedStream(Stream underlyingStream) - { - _stream = underlyingStream; - } - - public override bool CanRead => _stream.CanRead; - - public override bool CanSeek => _stream.CanSeek; - - public override bool CanWrite => _stream.CanWrite; - - public override long Length => _stream.Length; - - public override long Position { get => _stream.Position; set => _stream.Position = value; } - - public override void Flush() - { - _stream.Flush(); - } - - public override int Read(byte[] buffer, int offset, int count) - { - int n = _stream.Read(buffer, offset, count); - BytesRead += (ulong)n; - return n; - } - - public override long Seek(long offset, SeekOrigin origin) - { - return _stream.Seek(offset, origin); - } - - public override void SetLength(long value) - { - _stream.SetLength(value); - } - - public override void Write(byte[] buffer, int offset, int count) - { - BytesWritten += (ulong)count; - _stream.Write(buffer, offset, count); - } -} - +// Concatenates the content of multiple readable streams public class ConcatenatedStreamReader : Stream { private IEnumerable _streams; private IEnumerator _iter; private bool _finished; + public bool DisposeUnderlying = true; public ConcatenatedStreamReader(IEnumerable streams) { @@ -63,12 +15,17 @@ public class ConcatenatedStreamReader : Stream _finished = !_iter.MoveNext(); } + protected override void Dispose(bool disposing) + { + if (!DisposeUnderlying) + return; + foreach (var stream in _streams) + stream.Dispose(); + } + public override bool CanRead => true; - public override bool CanSeek => false; - public override bool CanWrite => false; - public override long Length => throw new NotSupportedException(); public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } @@ -92,6 +49,24 @@ public class ConcatenatedStreamReader : Stream return n; } + public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + int n = 0; + + while (n == 0 && !_finished) + { + n = await _iter.Current.ReadAsync(buffer, offset, count, cancellationToken); + + if (cancellationToken.IsCancellationRequested) + break; + + if (n == 0) + _finished = !_iter.MoveNext(); + } + + return n; + } + public override long Seek(long offset, SeekOrigin origin) { throw new NotSupportedException(); diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/CountedStream.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/CountedStream.cs new file mode 100644 index 0000000..d5412c5 --- /dev/null +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/CountedStream.cs @@ -0,0 +1,73 @@ + +namespace MareSynchronosStaticFilesServer.Utils; + +// Counts the number of bytes read/written to an underlying stream +public class CountedStream : Stream +{ + private readonly Stream _stream; + public long BytesRead { get; private set; } + public long BytesWritten { get; private set; } + public bool DisposeUnderlying = true; + + public Stream UnderlyingStream { get => _stream; } + + public CountedStream(Stream underlyingStream) + { + _stream = underlyingStream; + } + + protected override void Dispose(bool disposing) + { + if (!DisposeUnderlying) + return; + _stream.Dispose(); + } + + public override bool CanRead => _stream.CanRead; + public override bool CanSeek => _stream.CanSeek; + public override bool CanWrite => _stream.CanWrite; + public override long Length => _stream.Length; + + public override long Position { get => _stream.Position; set => _stream.Position = value; } + + public override void Flush() + { + _stream.Flush(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + int n = _stream.Read(buffer, offset, count); + BytesRead += n; + return n; + } + + public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + int n = await _stream.ReadAsync(buffer, offset, count, cancellationToken); + BytesRead += n; + return n; + } + + public override long Seek(long offset, SeekOrigin origin) + { + return _stream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _stream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + _stream.Write(buffer, offset, count); + BytesWritten += count; + } + + public async override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + await _stream.WriteAsync(buffer, offset, count, cancellationToken); + BytesWritten += count; + } +} diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/HashingStream.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/HashingStream.cs new file mode 100644 index 0000000..4fdcf5d --- /dev/null +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/HashingStream.cs @@ -0,0 +1,82 @@ +using System.Security.Cryptography; + +namespace MareSynchronosStaticFilesServer.Utils; + +// Calculates the hash of content read or written to a stream +public class HashingStream : Stream +{ + private readonly Stream _stream; + private readonly HashAlgorithm _hashAlgo; + private bool _finished = false; + public bool DisposeUnderlying = true; + + public Stream UnderlyingStream { get => _stream; } + + public HashingStream(Stream underlyingStream, HashAlgorithm hashAlgo) + { + _stream = underlyingStream; + _hashAlgo = hashAlgo; + } + + protected override void Dispose(bool disposing) + { + if (!DisposeUnderlying) + return; + if (!_finished) + _stream.Dispose(); + _hashAlgo.Dispose(); + } + + public override bool CanRead => _stream.CanRead; + public override bool CanSeek => false; + public override bool CanWrite => _stream.CanWrite; + public override long Length => _stream.Length; + + public override long Position { get => _stream.Position; set => throw new NotSupportedException(); } + + public override void Flush() + { + _stream.Flush(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (_finished) + throw new ObjectDisposedException("HashingStream"); + int n = _stream.Read(buffer, offset, count); + if (n > 0) + _hashAlgo.TransformBlock(buffer, offset, n, buffer, offset); + return n; + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + if (_finished) + throw new ObjectDisposedException("HashingStream"); + _stream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + if (_finished) + throw new ObjectDisposedException("HashingStream"); + _stream.Write(buffer, offset, count); + string x = new(System.Text.Encoding.ASCII.GetChars(buffer.AsSpan().Slice(offset, count).ToArray())); + _hashAlgo.TransformBlock(buffer, offset, count, buffer, offset); + } + + public byte[] Finish() + { + if (_finished) + return _hashAlgo.Hash; + _hashAlgo.TransformFinalBlock(Array.Empty(), 0, 0); + if (DisposeUnderlying) + _stream.Dispose(); + return _hashAlgo.Hash; + } +} diff --git a/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/TeeStream.cs b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/TeeStream.cs new file mode 100644 index 0000000..aed9355 --- /dev/null +++ b/MareSynchronosServer/MareSynchronosStaticFilesServer/Utils/TeeStream.cs @@ -0,0 +1,74 @@ +namespace MareSynchronosStaticFilesServer.Utils; + +// Writes data read from one stream out to a second stream +public class TeeStream : Stream +{ + private readonly Stream _inStream; + private readonly Stream _outStream; + public bool DisposeUnderlying = true; + + public Stream InStream { get => _inStream; } + public Stream OutStream { get => _outStream; } + + public TeeStream(Stream inStream, Stream outStream) + { + _inStream = inStream; + _outStream = outStream; + } + + protected override void Dispose(bool disposing) + { + if (!DisposeUnderlying) + return; + _inStream.Dispose(); + _outStream.Dispose(); + } + + public override bool CanRead => _inStream.CanRead; + public override bool CanSeek => _inStream.CanSeek; + public override bool CanWrite => false; + public override long Length => _inStream.Length; + + public override long Position + { + get => _inStream.Position; + set => _inStream.Position = value; + } + + public override void Flush() + { + _inStream.Flush(); + _outStream.Flush(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + int n = _inStream.Read(buffer, offset, count); + if (n > 0) + _outStream.Write(buffer, offset, n); + return n; + } + + public async override Task ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken) + { + int n = await _inStream.ReadAsync(buffer, offset, count, cancellationToken); + if (n > 0) + await _outStream.WriteAsync(buffer, offset, n); + return n; + } + + public override long Seek(long offset, SeekOrigin origin) + { + return _inStream.Seek(offset, origin); + } + + public override void SetLength(long value) + { + _inStream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + throw new NotSupportedException(); + } +}