refactor server auth on files server, add checking request queue

This commit is contained in:
rootdarkarchon
2023-01-18 10:20:24 +01:00
parent 20d8970a15
commit 9b4e298b66
6 changed files with 38 additions and 21 deletions

View File

@@ -4,6 +4,7 @@ using MareSynchronosStaticFilesServer.Utils;
using System.Collections.Concurrent;
using System.Net.Http.Headers;
using MareSynchronos.API;
using MareSynchronosShared.Utils;
namespace MareSynchronosStaticFilesServer.Services;
@@ -12,30 +13,32 @@ public class CachedFileProvider
private readonly ILogger<CachedFileProvider> _logger;
private readonly FileStatisticsService _fileStatisticsService;
private readonly MareMetrics _metrics;
private readonly ServerTokenGenerator _generator;
private readonly Uri _remoteCacheSourceUri;
private readonly string _basePath;
private readonly ConcurrentDictionary<string, Task> _currentTransfers = new(StringComparer.Ordinal);
private readonly HttpClient _httpClient;
private bool IsMainServer => _remoteCacheSourceUri == null;
public CachedFileProvider(IConfigurationService<StaticFilesServerConfiguration> configuration, ILogger<CachedFileProvider> logger, FileStatisticsService fileStatisticsService, MareMetrics metrics)
public CachedFileProvider(IConfigurationService<StaticFilesServerConfiguration> configuration, ILogger<CachedFileProvider> logger, FileStatisticsService fileStatisticsService, MareMetrics metrics, ServerTokenGenerator generator)
{
_logger = logger;
_fileStatisticsService = fileStatisticsService;
_metrics = metrics;
_generator = generator;
_remoteCacheSourceUri = configuration.GetValueOrDefault<Uri>(nameof(StaticFilesServerConfiguration.RemoteCacheSourceUri), null);
_basePath = configuration.GetValue<string>(nameof(StaticFilesServerConfiguration.CacheDirectory));
_httpClient = new HttpClient();
}
private async Task DownloadTask(string hash, string auth)
private async Task DownloadTask(string hash)
{
// download file from remote
var downloadUrl = MareFiles.ServerFilesGetFullPath(_remoteCacheSourceUri, hash);
_logger.LogInformation("Did not find {hash}, downloading from {server}", hash, downloadUrl);
using var requestMessage = new HttpRequestMessage(HttpMethod.Get, downloadUrl);
requestMessage.Headers.Authorization = new AuthenticationHeaderValue("Bearer", auth);
requestMessage.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _generator.Token);
var response = await _httpClient.SendAsync(requestMessage).ConfigureAwait(false);
try
@@ -63,7 +66,7 @@ public class CachedFileProvider
_metrics.IncGauge(MetricsAPI.GaugeFilesTotalSize, FilePathUtil.GetFileInfoForHash(_basePath, hash).Length);
}
public void DownloadFileWhenRequired(string hash, string auth)
public void DownloadFileWhenRequired(string hash)
{
var fi = FilePathUtil.GetFileInfoForHash(_basePath, hash);
if (fi == null && IsMainServer) return;
@@ -72,7 +75,7 @@ public class CachedFileProvider
{
_currentTransfers[hash] = Task.Run(async () =>
{
await DownloadTask(hash, auth).ConfigureAwait(false);
await DownloadTask(hash).ConfigureAwait(false);
_currentTransfers.Remove(hash, out _);
});
}
@@ -88,9 +91,9 @@ public class CachedFileProvider
return new FileStream(fi.FullName, FileMode.Open, FileAccess.Read, FileShare.Inheritable | FileShare.Read);
}
public async Task<FileStream?> GetAndDownloadFileStream(string hash, string auth)
public async Task<FileStream?> GetAndDownloadFileStream(string hash)
{
DownloadFileWhenRequired(hash, auth);
DownloadFileWhenRequired(hash);
if (_currentTransfers.TryGetValue(hash, out var downloadTask))
{