diff --git a/MareSynchronos/FileCache/FileCacheManager.cs b/MareSynchronos/FileCache/FileCacheManager.cs index 4243633..43585c9 100644 --- a/MareSynchronos/FileCache/FileCacheManager.cs +++ b/MareSynchronos/FileCache/FileCacheManager.cs @@ -41,7 +41,7 @@ public sealed class FileCacheManager : IHostedService private string CsvBakPath => _csvPath + ".bak"; - public FileCacheEntity? CreateCacheEntry(string path) + public FileCacheEntity? CreateCacheEntry(string path, string? hash = null) { FileInfo fi = new(path); if (!fi.Exists) return null; @@ -49,7 +49,10 @@ public sealed class FileCacheManager : IHostedService var fullName = fi.FullName.ToLowerInvariant(); if (!fullName.Contains(_configService.Current.CacheFolder.ToLowerInvariant(), StringComparison.Ordinal)) return null; string prefixedPath = fullName.Replace(_configService.Current.CacheFolder.ToLowerInvariant(), CachePrefix + "\\", StringComparison.Ordinal).Replace("\\\\", "\\", StringComparison.Ordinal); - return CreateFileCacheEntity(fi, prefixedPath); + if (hash != null) + return CreateFileCacheEntity(fi, prefixedPath, hash); + else + return CreateFileCacheEntity(fi, prefixedPath); } public FileCacheEntity? CreateSubstEntry(string path) diff --git a/MareSynchronos/Utils/HashingStream.cs b/MareSynchronos/Utils/HashingStream.cs new file mode 100644 index 0000000..4fdcf5d --- /dev/null +++ b/MareSynchronos/Utils/HashingStream.cs @@ -0,0 +1,82 @@ +using System.Security.Cryptography; + +namespace MareSynchronosStaticFilesServer.Utils; + +// Calculates the hash of content read or written to a stream +public class HashingStream : Stream +{ + private readonly Stream _stream; + private readonly HashAlgorithm _hashAlgo; + private bool _finished = false; + public bool DisposeUnderlying = true; + + public Stream UnderlyingStream { get => _stream; } + + public HashingStream(Stream underlyingStream, HashAlgorithm hashAlgo) + { + _stream = underlyingStream; + _hashAlgo = hashAlgo; + } + + protected override void Dispose(bool disposing) + { + if (!DisposeUnderlying) + return; + if (!_finished) + _stream.Dispose(); + _hashAlgo.Dispose(); + } + + public override bool CanRead => _stream.CanRead; + public override bool CanSeek => false; + public override bool CanWrite => _stream.CanWrite; + public override long Length => _stream.Length; + + public override long Position { get => _stream.Position; set => throw new NotSupportedException(); } + + public override void Flush() + { + _stream.Flush(); + } + + public override int Read(byte[] buffer, int offset, int count) + { + if (_finished) + throw new ObjectDisposedException("HashingStream"); + int n = _stream.Read(buffer, offset, count); + if (n > 0) + _hashAlgo.TransformBlock(buffer, offset, n, buffer, offset); + return n; + } + + public override long Seek(long offset, SeekOrigin origin) + { + throw new NotSupportedException(); + } + + public override void SetLength(long value) + { + if (_finished) + throw new ObjectDisposedException("HashingStream"); + _stream.SetLength(value); + } + + public override void Write(byte[] buffer, int offset, int count) + { + if (_finished) + throw new ObjectDisposedException("HashingStream"); + _stream.Write(buffer, offset, count); + string x = new(System.Text.Encoding.ASCII.GetChars(buffer.AsSpan().Slice(offset, count).ToArray())); + _hashAlgo.TransformBlock(buffer, offset, count, buffer, offset); + } + + public byte[] Finish() + { + if (_finished) + return _hashAlgo.Hash; + _hashAlgo.TransformFinalBlock(Array.Empty(), 0, 0); + if (DisposeUnderlying) + _stream.Dispose(); + return _hashAlgo.Hash; + } +} diff --git a/MareSynchronos/WebAPI/Files/FileDownloadManager.cs b/MareSynchronos/WebAPI/Files/FileDownloadManager.cs index 89d9f60..7368001 100644 --- a/MareSynchronos/WebAPI/Files/FileDownloadManager.cs +++ b/MareSynchronos/WebAPI/Files/FileDownloadManager.cs @@ -8,9 +8,11 @@ using MareSynchronos.PlayerData.Handlers; using MareSynchronos.Services.Mediator; using MareSynchronos.Utils; using MareSynchronos.WebAPI.Files.Models; +using MareSynchronosStaticFilesServer.Utils; using Microsoft.Extensions.Logging; using System.Net; using System.Net.Http.Json; +using System.Security.Cryptography; namespace MareSynchronos.WebAPI.Files; @@ -333,12 +335,12 @@ public partial class FileDownloadManager : DisposableMediatorSubscriberBase tasks.Add(Task.Run(() => { try { - using var tmpFileStream = new FileStream(tmpPath, new FileStreamOptions() + using var tmpFileStream = new HashingStream(new FileStream(tmpPath, new FileStreamOptions() { Mode = FileMode.CreateNew, Access = FileAccess.Write, Share = FileShare.None - }); + }), SHA1.Create()); using var fileChunkStream = new FileStream(blockFile, new FileStreamOptions() { @@ -359,6 +361,14 @@ public partial class FileDownloadManager : DisposableMediatorSubscriberBase throw new EndOfStreamException(); } + string calculatedHash = BitConverter.ToString(tmpFileStream.Finish()).Replace("-", "", StringComparison.Ordinal); + + if (calculatedHash != fileHash) + { + Logger.LogError("Hash mismatch after extracting, got {hash}, expected {expectedHash}, deleting file", calculatedHash, fileHash); + return; + } + tmpFileStream.Close(); _fileCompactor.RenameAndCompact(filePath, tmpPath); PersistFileToStorage(fileHash, filePath, fileLengthBytes); @@ -418,11 +428,9 @@ public partial class FileDownloadManager : DisposableMediatorSubscriberBase { try { - var entry = _fileDbManager.CreateCacheEntry(filePath); + var entry = _fileDbManager.CreateCacheEntry(filePath, fileHash); if (entry != null && !string.Equals(entry.Hash, fileHash, StringComparison.OrdinalIgnoreCase)) { - Logger.LogError("Hash mismatch after extracting, got {hash}, expected {expectedHash}, deleting file", entry.Hash, fileHash); - File.Delete(filePath); _fileDbManager.RemoveHashedFile(entry.Hash, entry.PrefixedFilePath); entry = null; }