Use streamable compression (needs file cache clear)

This commit is contained in:
Loporrit
2023-12-18 12:51:36 +00:00
parent 0707d3eb54
commit 8cf4f50091
7 changed files with 297 additions and 141 deletions

View File

@@ -27,7 +27,6 @@
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference> </PackageReference>
<PackageReference Include="lz4net" Version="1.0.15.93" />
<PackageReference Include="Meziantou.Analyzer" Version="2.0.49"> <PackageReference Include="Meziantou.Analyzer" Version="2.0.49">
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>

View File

@@ -1,4 +1,4 @@
using LZ4; using K4os.Compression.LZ4.Streams;
using MareSynchronos.API.Dto.Files; using MareSynchronos.API.Dto.Files;
using MareSynchronos.API.Routes; using MareSynchronos.API.Routes;
using MareSynchronos.API.SignalR; using MareSynchronos.API.SignalR;
@@ -183,115 +183,68 @@ public class ServerFilesController : ControllerBase
return Ok(); return Ok();
} }
// copy the request body to memory
using var compressedFileStream = new MemoryStream();
await Request.Body.CopyToAsync(compressedFileStream, requestAborted).ConfigureAwait(false);
// decompress and copy the decompressed stream to memory
var data = LZ4Codec.Unwrap(compressedFileStream.ToArray());
// reset streams
compressedFileStream.Seek(0, SeekOrigin.Begin);
// compute hash to verify
var hashString = BitConverter.ToString(SHA1.HashData(data))
.Replace("-", "", StringComparison.Ordinal).ToUpperInvariant();
if (!string.Equals(hashString, hash, StringComparison.Ordinal))
throw new InvalidOperationException($"Hash does not match file, computed: {hashString}, expected: {hash}");
// save file
var path = FilePathUtil.GetFilePath(_basePath, hash); var path = FilePathUtil.GetFilePath(_basePath, hash);
using var fileStream = new FileStream(path, FileMode.Create); var tmpPath = path + ".tmp";
await compressedFileStream.CopyToAsync(fileStream).ConfigureAwait(false); long compressedSize = -1;
// update on db try
await _mareDbContext.Files.AddAsync(new FileCache()
{ {
Hash = hash, // Write incoming file to a temporary file while also hashing the decompressed content
UploadDate = DateTime.UtcNow,
UploaderUID = MareUser,
Size = compressedFileStream.Length,
Uploaded = true
}).ConfigureAwait(false);
await _mareDbContext.SaveChangesAsync().ConfigureAwait(false);
_metricsClient.IncGauge(MetricsAPI.GaugeFilesTotal, 1); // Stream flow diagram:
_metricsClient.IncGauge(MetricsAPI.GaugeFilesTotalSize, compressedFileStream.Length); // Request.Body ==> (Tee) ==> FileStream
// ==> CountedStream ==> LZ4DecoderStream ==> HashingStream ==> Stream.Null
_fileUploadLocks.TryRemove(hash, out _); // Reading via TeeStream causes the request body to be copied to tmpPath
using var tmpFileStream = new FileStream(tmpPath, FileMode.Create);
using var teeStream = new TeeStream(Request.Body, tmpFileStream);
teeStream.DisposeUnderlying = false;
// Read via CountedStream to count the number of compressed bytes
using var countStream = new CountedStream(teeStream);
countStream.DisposeUnderlying = false;
return Ok(); // The decompressed file content is read through LZ4DecoderStream, and written out to HashingStream
} using var decStream = LZ4Stream.Decode(countStream, extraMemory: 0, leaveOpen: true);
catch (Exception e) // HashingStream simply hashes the decompressed bytes without writing them anywhere
{ using var hashStream = new HashingStream(Stream.Null, SHA1.Create());
_logger.LogError(e, "Error during file upload"); hashStream.DisposeUnderlying = false;
return BadRequest();
}
finally
{
fileLock.Release();
}
}
[HttpPost(MareFiles.ServerFiles_UploadRaw + "/{hash}")] await decStream.CopyToAsync(hashStream, requestAborted).ConfigureAwait(false);
[RequestSizeLimit(200 * 1024 * 1024)] decStream.Close();
public async Task<IActionResult> UploadFileRaw(string hash, CancellationToken requestAborted)
{
_logger.LogInformation("{user} uploading raw file {file}", MareUser, hash);
hash = hash.ToUpperInvariant();
var existingFile = await _mareDbContext.Files.SingleOrDefaultAsync(f => f.Hash == hash);
if (existingFile != null) return Ok();
SemaphoreSlim fileLock; var hashString = BitConverter.ToString(hashStream.Finish())
lock (_fileUploadLocks) .Replace("-", "", StringComparison.Ordinal).ToUpperInvariant();
{ if (!string.Equals(hashString, hash, StringComparison.Ordinal))
if (!_fileUploadLocks.TryGetValue(hash, out fileLock)) throw new InvalidOperationException($"Hash does not match file, computed: {hashString}, expected: {hash}");
_fileUploadLocks[hash] = fileLock = new SemaphoreSlim(1);
}
await fileLock.WaitAsync(requestAborted).ConfigureAwait(false); compressedSize = countStream.BytesRead;
try // File content is verified -- move it to its final location
{ System.IO.File.Move(tmpPath, path, true);
var existingFileCheck2 = await _mareDbContext.Files.SingleOrDefaultAsync(f => f.Hash == hash); }
if (existingFileCheck2 != null) catch
{ {
return Ok(); try
{
System.IO.File.Delete(tmpPath);
}
catch { }
throw;
} }
// copy the request body to memory
using var rawFileStream = new MemoryStream();
await Request.Body.CopyToAsync(rawFileStream, requestAborted).ConfigureAwait(false);
// reset streams
rawFileStream.Seek(0, SeekOrigin.Begin);
// compute hash to verify
var hashString = BitConverter.ToString(SHA1.HashData(rawFileStream.ToArray()))
.Replace("-", "", StringComparison.Ordinal).ToUpperInvariant();
if (!string.Equals(hashString, hash, StringComparison.Ordinal))
throw new InvalidOperationException($"Hash does not match file, computed: {hashString}, expected: {hash}");
// save file
var path = FilePathUtil.GetFilePath(_basePath, hash);
using var fileStream = new FileStream(path, FileMode.Create);
var lz4 = LZ4Codec.WrapHC(rawFileStream.ToArray(), 0, (int)rawFileStream.Length);
using var compressedStream = new MemoryStream(lz4);
await compressedStream.CopyToAsync(fileStream).ConfigureAwait(false);
// update on db // update on db
await _mareDbContext.Files.AddAsync(new FileCache() await _mareDbContext.Files.AddAsync(new FileCache()
{ {
Hash = hash, Hash = hash,
UploadDate = DateTime.UtcNow, UploadDate = DateTime.UtcNow,
UploaderUID = MareUser, UploaderUID = MareUser,
Size = compressedStream.Length, Size = compressedSize,
Uploaded = true Uploaded = true
}).ConfigureAwait(false); }).ConfigureAwait(false);
await _mareDbContext.SaveChangesAsync().ConfigureAwait(false); await _mareDbContext.SaveChangesAsync().ConfigureAwait(false);
_metricsClient.IncGauge(MetricsAPI.GaugeFilesTotal, 1); _metricsClient.IncGauge(MetricsAPI.GaugeFilesTotal, 1);
_metricsClient.IncGauge(MetricsAPI.GaugeFilesTotalSize, rawFileStream.Length); _metricsClient.IncGauge(MetricsAPI.GaugeFilesTotalSize, compressedSize);
_fileUploadLocks.TryRemove(hash, out _); _fileUploadLocks.TryRemove(hash, out _);

View File

@@ -18,7 +18,7 @@
</ItemGroup> </ItemGroup>
<ItemGroup> <ItemGroup>
<PackageReference Include="lz4net" Version="1.0.15.93" /> <PackageReference Include="K4os.Compression.LZ4.Streams" Version="1.3.6" />
<PackageReference Include="Meziantou.Analyzer" Version="2.0.49"> <PackageReference Include="Meziantou.Analyzer" Version="2.0.49">
<PrivateAssets>all</PrivateAssets> <PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets> <IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>

View File

@@ -1,60 +1,12 @@
namespace MareSynchronosStaticFilesServer.Utils; namespace MareSynchronosStaticFilesServer.Utils;
public class CountedStream : Stream // Concatenates the content of multiple readable streams
{
private readonly Stream _stream;
public ulong BytesRead { get; private set; }
public ulong BytesWritten { get; private set; }
public CountedStream(Stream underlyingStream)
{
_stream = underlyingStream;
}
public override bool CanRead => _stream.CanRead;
public override bool CanSeek => _stream.CanSeek;
public override bool CanWrite => _stream.CanWrite;
public override long Length => _stream.Length;
public override long Position { get => _stream.Position; set => _stream.Position = value; }
public override void Flush()
{
_stream.Flush();
}
public override int Read(byte[] buffer, int offset, int count)
{
int n = _stream.Read(buffer, offset, count);
BytesRead += (ulong)n;
return n;
}
public override long Seek(long offset, SeekOrigin origin)
{
return _stream.Seek(offset, origin);
}
public override void SetLength(long value)
{
_stream.SetLength(value);
}
public override void Write(byte[] buffer, int offset, int count)
{
BytesWritten += (ulong)count;
_stream.Write(buffer, offset, count);
}
}
public class ConcatenatedStreamReader : Stream public class ConcatenatedStreamReader : Stream
{ {
private IEnumerable<Stream> _streams; private IEnumerable<Stream> _streams;
private IEnumerator<Stream> _iter; private IEnumerator<Stream> _iter;
private bool _finished; private bool _finished;
public bool DisposeUnderlying = true;
public ConcatenatedStreamReader(IEnumerable<Stream> streams) public ConcatenatedStreamReader(IEnumerable<Stream> streams)
{ {
@@ -63,12 +15,17 @@ public class ConcatenatedStreamReader : Stream
_finished = !_iter.MoveNext(); _finished = !_iter.MoveNext();
} }
protected override void Dispose(bool disposing)
{
if (!DisposeUnderlying)
return;
foreach (var stream in _streams)
stream.Dispose();
}
public override bool CanRead => true; public override bool CanRead => true;
public override bool CanSeek => false; public override bool CanSeek => false;
public override bool CanWrite => false; public override bool CanWrite => false;
public override long Length => throw new NotSupportedException(); public override long Length => throw new NotSupportedException();
public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); } public override long Position { get => throw new NotSupportedException(); set => throw new NotSupportedException(); }
@@ -92,6 +49,24 @@ public class ConcatenatedStreamReader : Stream
return n; return n;
} }
public async override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
{
int n = 0;
while (n == 0 && !_finished)
{
n = await _iter.Current.ReadAsync(buffer, offset, count, cancellationToken);
if (cancellationToken.IsCancellationRequested)
break;
if (n == 0)
_finished = !_iter.MoveNext();
}
return n;
}
public override long Seek(long offset, SeekOrigin origin) public override long Seek(long offset, SeekOrigin origin)
{ {
throw new NotSupportedException(); throw new NotSupportedException();

View File

@@ -0,0 +1,73 @@
namespace MareSynchronosStaticFilesServer.Utils;
// Counts the number of bytes read/written to an underlying stream
// Counts the number of bytes read from / written to an underlying stream.
public class CountedStream : Stream
{
    private readonly Stream _stream;

    // Total bytes that have passed through Read/ReadAsync
    public long BytesRead { get; private set; }

    // Total bytes that have passed through Write/WriteAsync
    public long BytesWritten { get; private set; }

    // When false, disposing this wrapper leaves the underlying stream open
    public bool DisposeUnderlying = true;

    public Stream UnderlyingStream { get => _stream; }

    public CountedStream(Stream underlyingStream)
    {
        _stream = underlyingStream;
    }

    protected override void Dispose(bool disposing)
    {
        // Only release managed state on an explicit Dispose(), never from a finalizer
        if (disposing && DisposeUnderlying)
            _stream.Dispose();
        base.Dispose(disposing);
    }

    public override bool CanRead => _stream.CanRead;
    public override bool CanSeek => _stream.CanSeek;
    public override bool CanWrite => _stream.CanWrite;
    public override long Length => _stream.Length;
    public override long Position { get => _stream.Position; set => _stream.Position = value; }

    public override void Flush()
    {
        _stream.Flush();
    }

    public override int Read(byte[] buffer, int offset, int count)
    {
        int n = _stream.Read(buffer, offset, count);
        BytesRead += n;
        return n;
    }

    public async override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
    {
        // ConfigureAwait(false): no synchronization context is needed in server code
        int n = await _stream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
        BytesRead += n;
        return n;
    }

    public override long Seek(long offset, SeekOrigin origin)
    {
        return _stream.Seek(offset, origin);
    }

    public override void SetLength(long value)
    {
        _stream.SetLength(value);
    }

    public override void Write(byte[] buffer, int offset, int count)
    {
        _stream.Write(buffer, offset, count);
        // Count only after the write succeeded
        BytesWritten += count;
    }

    public async override Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
    {
        await _stream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
        BytesWritten += count;
    }
}

View File

@@ -0,0 +1,82 @@
using System.Security.Cryptography;
namespace MareSynchronosStaticFilesServer.Utils;
// Calculates the hash of content read or written to a stream
// Calculates the hash of content read from or written to a stream.
public class HashingStream : Stream
{
    private readonly Stream _stream;
    private readonly HashAlgorithm _hashAlgo;

    // Set once Finish() has run; further reads/writes are rejected
    private bool _finished = false;

    // When false, disposing this wrapper leaves the underlying stream open
    public bool DisposeUnderlying = true;

    public Stream UnderlyingStream { get => _stream; }

    public HashingStream(Stream underlyingStream, HashAlgorithm hashAlgo)
    {
        _stream = underlyingStream;
        _hashAlgo = hashAlgo;
    }

    protected override void Dispose(bool disposing)
    {
        if (disposing)
        {
            // The hash algorithm is owned by this wrapper and must always be
            // released, even when DisposeUnderlying is false (otherwise the
            // caller's SHA1.Create() instance leaks)
            _hashAlgo.Dispose();
            // Finish() already disposed the underlying stream when DisposeUnderlying is set
            if (DisposeUnderlying && !_finished)
                _stream.Dispose();
        }
        base.Dispose(disposing);
    }

    public override bool CanRead => _stream.CanRead;
    public override bool CanSeek => false;
    public override bool CanWrite => _stream.CanWrite;
    public override long Length => _stream.Length;
    public override long Position { get => _stream.Position; set => throw new NotSupportedException(); }

    public override void Flush()
    {
        _stream.Flush();
    }

    public override int Read(byte[] buffer, int offset, int count)
    {
        if (_finished)
            throw new ObjectDisposedException("HashingStream");
        int n = _stream.Read(buffer, offset, count);
        if (n > 0)
            _hashAlgo.TransformBlock(buffer, offset, n, buffer, offset);
        return n;
    }

    public override long Seek(long offset, SeekOrigin origin)
    {
        throw new NotSupportedException();
    }

    public override void SetLength(long value)
    {
        if (_finished)
            throw new ObjectDisposedException("HashingStream");
        _stream.SetLength(value);
    }

    public override void Write(byte[] buffer, int offset, int count)
    {
        if (_finished)
            throw new ObjectDisposedException("HashingStream");
        _stream.Write(buffer, offset, count);
        // (removed leftover debug line that decoded the buffer to an unused ASCII string)
        _hashAlgo.TransformBlock(buffer, offset, count, buffer, offset);
    }

    // Completes the hash computation and returns the digest.
    // Safe to call multiple times; subsequent calls return the same digest.
    public byte[] Finish()
    {
        if (_finished)
            return _hashAlgo.Hash;
        _hashAlgo.TransformFinalBlock(Array.Empty<byte>(), 0, 0);
        // Mark finished so a second Finish() does not re-run TransformFinalBlock
        // (which would reset the algorithm and yield a wrong hash), and so the
        // Read/Write guards above start rejecting further use
        _finished = true;
        if (DisposeUnderlying)
            _stream.Dispose();
        return _hashAlgo.Hash;
    }
}

View File

@@ -0,0 +1,74 @@
namespace MareSynchronosStaticFilesServer.Utils;
// Writes data read from one stream out to a second stream
// Writes data read from one stream out to a second stream.
public class TeeStream : Stream
{
    private readonly Stream _inStream;
    private readonly Stream _outStream;

    // When false, disposing this wrapper leaves both underlying streams open
    public bool DisposeUnderlying = true;

    public Stream InStream { get => _inStream; }
    public Stream OutStream { get => _outStream; }

    public TeeStream(Stream inStream, Stream outStream)
    {
        _inStream = inStream;
        _outStream = outStream;
    }

    protected override void Dispose(bool disposing)
    {
        // Only release managed state on an explicit Dispose(), never from a finalizer
        if (disposing && DisposeUnderlying)
        {
            _inStream.Dispose();
            _outStream.Dispose();
        }
        base.Dispose(disposing);
    }

    public override bool CanRead => _inStream.CanRead;
    public override bool CanSeek => _inStream.CanSeek;
    public override bool CanWrite => false;
    public override long Length => _inStream.Length;

    // NOTE(review): seeking the input does not rewind what was already copied
    // to the output stream — callers should read the tee strictly sequentially
    public override long Position
    {
        get => _inStream.Position;
        set => _inStream.Position = value;
    }

    public override void Flush()
    {
        _inStream.Flush();
        _outStream.Flush();
    }

    public override int Read(byte[] buffer, int offset, int count)
    {
        int n = _inStream.Read(buffer, offset, count);
        if (n > 0)
            _outStream.Write(buffer, offset, n);
        return n;
    }

    public async override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
    {
        int n = await _inStream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
        if (n > 0)
            // Propagate the cancellation token so the copy can be aborted too
            await _outStream.WriteAsync(buffer, offset, n, cancellationToken).ConfigureAwait(false);
        return n;
    }

    public override long Seek(long offset, SeekOrigin origin)
    {
        return _inStream.Seek(offset, origin);
    }

    public override void SetLength(long value)
    {
        _inStream.SetLength(value);
    }

    public override void Write(byte[] buffer, int offset, int count)
    {
        throw new NotSupportedException();
    }
}