namespace MareSynchronos.WebAPI.Files
{
    /// 
    ///     Class for streaming data with throttling support.
    ///     Borrowed from https://github.com/bezzad/Downloader
    /// 
    internal class ThrottledStream : Stream
    {
        public static long Infinite => long.MaxValue;
        private readonly Stream _baseStream;
        private long _bandwidthLimit;
        private Bandwidth _bandwidth;
        private CancellationTokenSource _bandwidthChangeTokenSource = new CancellationTokenSource();
        /// 
        ///     Initializes a new instance of the  class.
        /// 
        /// The base stream.
        /// The maximum bytes per second that can be transferred through the base stream.
        /// Thrown when  is a null reference.
        /// Thrown when  is a negative value.
        public ThrottledStream(Stream baseStream, long bandwidthLimit)
        {
            if (bandwidthLimit < 0)
            {
                throw new ArgumentOutOfRangeException(nameof(bandwidthLimit),
                    bandwidthLimit, "The maximum number of bytes per second can't be negative.");
            }
            _baseStream = baseStream ?? throw new ArgumentNullException(nameof(baseStream));
            BandwidthLimit = bandwidthLimit;
        }
        /// 
        ///     Bandwidth Limit (in B/s)
        /// 
        /// The maximum bytes per second.
        public long BandwidthLimit
        {
            get => _bandwidthLimit;
            set
            {
                if (_bandwidthLimit == value) return;
                _bandwidthLimit = value <= 0 ? Infinite : value;
                _bandwidth ??= new Bandwidth();
                _bandwidth.BandwidthLimit = _bandwidthLimit;
                _bandwidthChangeTokenSource.Cancel();
                _bandwidthChangeTokenSource.Dispose();
                _bandwidthChangeTokenSource = new();
            }
        }
        /// 
        public override bool CanRead => _baseStream.CanRead;
        /// 
        public override bool CanSeek => _baseStream.CanSeek;
        /// 
        public override bool CanWrite => _baseStream.CanWrite;
        /// 
        public override long Length => _baseStream.Length;
        /// 
        public override long Position
        {
            get => _baseStream.Position;
            set => _baseStream.Position = value;
        }
        /// 
        public override void Flush()
        {
            _baseStream.Flush();
        }
        /// 
        public override long Seek(long offset, SeekOrigin origin)
        {
            return _baseStream.Seek(offset, origin);
        }
        /// 
        public override void SetLength(long value)
        {
            _baseStream.SetLength(value);
        }
        /// 
        public override int Read(byte[] buffer, int offset, int count)
        {
            Throttle(count).Wait();
            return _baseStream.Read(buffer, offset, count);
        }
        public override async Task ReadAsync(byte[] buffer, int offset, int count,
            CancellationToken cancellationToken)
        {
            await Throttle(count, cancellationToken).ConfigureAwait(false);
            return await _baseStream.ReadAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
        }
        /// 
        public override void Write(byte[] buffer, int offset, int count)
        {
            Throttle(count).Wait();
            _baseStream.Write(buffer, offset, count);
        }
        /// 
        public override async Task WriteAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
        {
            await Throttle(count, cancellationToken).ConfigureAwait(false);
            await _baseStream.WriteAsync(buffer, offset, count, cancellationToken).ConfigureAwait(false);
        }
        public override void Close()
        {
            _baseStream.Close();
            base.Close();
        }
        private async Task Throttle(int transmissionVolume, CancellationToken token = default)
        {
            // Make sure the buffer isn't empty.
            if (BandwidthLimit > 0 && transmissionVolume > 0)
            {
                // Calculate the time to sleep.
                _bandwidth.CalculateSpeed(transmissionVolume);
                await Sleep(_bandwidth.PopSpeedRetrieveTime(), token).ConfigureAwait(false);
            }
        }
        private async Task Sleep(int time, CancellationToken token = default)
        {
            try
            {
                if (time > 0)
                {
                    var bandWidthtoken = _bandwidthChangeTokenSource.Token;
                    var linked = CancellationTokenSource.CreateLinkedTokenSource(token, bandWidthtoken).Token;
                    await Task.Delay(time, linked).ConfigureAwait(false);
                }
            }
            catch (TaskCanceledException)
            {
                // ignore
            }
        }
        /// 
        public override string ToString()
        {
            return _baseStream?.ToString() ?? string.Empty;
        }
        private sealed class Bandwidth
        {
            private long _count;
            private int _lastSecondCheckpoint;
            private long _lastTransferredBytesCount;
            private int _speedRetrieveTime;
            public double Speed { get; private set; }
            public double AverageSpeed { get; private set; }
            public long BandwidthLimit { get; set; }
            public Bandwidth()
            {
                BandwidthLimit = long.MaxValue;
                Reset();
            }
            public void CalculateSpeed(long receivedBytesCount)
            {
                int elapsedTime = Environment.TickCount - _lastSecondCheckpoint + 1;
                receivedBytesCount = Interlocked.Add(ref _lastTransferredBytesCount, receivedBytesCount);
                double momentSpeed = receivedBytesCount * 1000 / (double)elapsedTime; // B/s
                if (1000 < elapsedTime)
                {
                    Speed = momentSpeed;
                    AverageSpeed = ((AverageSpeed * _count) + Speed) / (_count + 1);
                    _count++;
                    SecondCheckpoint();
                }
                if (momentSpeed >= BandwidthLimit)
                {
                    var expectedTime = receivedBytesCount * 1000 / BandwidthLimit;
                    Interlocked.Add(ref _speedRetrieveTime, (int)expectedTime - elapsedTime);
                }
            }
            public int PopSpeedRetrieveTime()
            {
                return Interlocked.Exchange(ref _speedRetrieveTime, 0);
            }
            public void Reset()
            {
                SecondCheckpoint();
                _count = 0;
                Speed = 0;
                AverageSpeed = 0;
            }
            private void SecondCheckpoint()
            {
                Interlocked.Exchange(ref _lastSecondCheckpoint, Environment.TickCount);
                Interlocked.Exchange(ref _lastTransferredBytesCount, 0);
            }
        }
    }
}