using System.Globalization;
using System.Net.Http.Headers;
using System.Security.Claims;
using System.Text.Json;
using Microsoft.AspNetCore.Components.Authorization;
using Microsoft.Extensions.Logging;
using WatchIt.Common.Model.Accounts;
using WatchIt.Website.Services.Tokens;
using WatchIt.Website.Services.Client.Accounts;

namespace WatchIt.Website.Services.Authentication;

/// <summary>
/// Blazor <see cref="AuthenticationStateProvider"/> that builds the current user from the JWT access token
/// returned by <see cref="ITokensService"/>. When the access token is missing or expired it tries a refresh
/// through <see cref="IAccountsClientService"/>. On success it sets the token as the bearer
/// <c>Authorization</c> header of the injected <see cref="HttpClient"/>.
/// </summary>
public class JWTAuthenticationStateProvider : AuthenticationStateProvider
{
    #region SERVICES

    private readonly HttpClient _httpClient;
    private readonly ILogger<JWTAuthenticationStateProvider> _logger;
    private readonly ITokensService _tokensService;
    private readonly IAccountsClientService _accountsService;

    #endregion



    #region CONSTRUCTORS

    public JWTAuthenticationStateProvider(HttpClient httpClient, ILogger<JWTAuthenticationStateProvider> logger, ITokensService tokensService, IAccountsClientService accountsService)
    {
        _httpClient = httpClient;
        _logger = logger;
        _tokensService = tokensService;
        _accountsService = accountsService;
    }

    #endregion



    #region PUBLIC METHODS

    /// <summary>
    /// Builds the authentication state from the stored access token.
    /// </summary>
    /// <remarks>
    /// A refresh is attempted at most once per call. It happens when no access token is stored, or when the
    /// token has no future <c>exp</c> claim. If no usable token is obtained, an anonymous state (identity
    /// without claims) is returned. Otherwise the token is set as the bearer header of the shared
    /// <see cref="HttpClient"/>.
    /// </remarks>
    /// <exception cref="FormatException">The access token is not a decodable JWT.</exception>
    public override async Task<AuthenticationState> GetAuthenticationStateAsync()
    {
        AuthenticationState state = new AuthenticationState(new ClaimsPrincipal(new ClaimsIdentity()));

        Task<string?> accessTokenTask = _tokensService.GetAccessToken();
        Task<string?> refreshTokenTask = _tokensService.GetRefreshToken();
        await Task.WhenAll(accessTokenTask, refreshTokenTask);
        string? accessToken = await accessTokenTask;
        string? refreshToken = await refreshTokenTask;

        bool refreshed = false;

        if (string.IsNullOrWhiteSpace(accessToken))
        {
            if (string.IsNullOrWhiteSpace(refreshToken))
            {
                return state;
            }

            string? accessTokenNew = await Refresh(refreshToken);

            // Check the token we just received (the old code re-checked the empty original token).
            if (string.IsNullOrWhiteSpace(accessTokenNew))
            {
                return state;
            }

            accessToken = accessTokenNew;
            refreshed = true;
        }

        IEnumerable<Claim> claims = GetClaimsFromToken(accessToken);

        if (!IsUnexpired(claims))
        {
            // A token obtained by refreshing during this call is already expired; refreshing again would not help.
            if (refreshed || string.IsNullOrWhiteSpace(refreshToken))
            {
                return state;
            }

            string? accessTokenNew = await Refresh(refreshToken);
            if (string.IsNullOrWhiteSpace(accessTokenNew))
            {
                return state;
            }

            accessToken = accessTokenNew;

            // The identity must describe the new token, not the expired one.
            claims = GetClaimsFromToken(accessToken);
        }

        _httpClient.DefaultRequestHeaders.Authorization = new AuthenticationHeaderValue("Bearer", accessToken.Replace("\"", ""));

        // A non-empty authentication type is what makes ClaimsIdentity.IsAuthenticated return true.
        return new AuthenticationState(new ClaimsPrincipal(new ClaimsIdentity(claims, "jwt")));
    }

    #endregion



    #region PRIVATE METHODS

    /// <summary>
    /// Asks the accounts service for new authentication data.
    /// </summary>
    /// <remarks>
    /// If a response is delivered through the callback, it is saved with
    /// <see cref="ITokensService"/>. Otherwise the stored authentication data is removed.
    /// </remarks>
    /// <param name="refreshToken">
    /// Not read by this method. NOTE(review): presumably the accounts client obtains the refresh token itself — confirm.
    /// </param>
    /// <returns>The new access token, or <see langword="null"/> if no response was delivered.</returns>
    private async Task<string?> Refresh(string refreshToken)
    {
        AuthenticateResponse? response = null;

        void SetResponse(AuthenticateResponse data)
        {
            response = data;
        }

        await _accountsService.AuthenticateRefresh(SetResponse);

        if (response is not null)
        {
            await _tokensService.SaveAuthenticationData(response);
        }
        else
        {
            await _tokensService.RemoveAuthenticationData();
        }

        return response?.AccessToken;
    }

    /// <summary>
    /// Returns <see langword="true"/> when <paramref name="claims"/> contains an integer <c>exp</c> claim
    /// (seconds since the Unix epoch) later than the current UTC time.
    /// A missing or non-integer <c>exp</c> counts as expired.
    /// </summary>
    private static bool IsUnexpired(IEnumerable<Claim> claims)
    {
        Claim? expClaim = claims.FirstOrDefault(c => c.Type == "exp");

        // long + invariant culture: no overflow after 2038, no dependence on the UI culture.
        return expClaim is not null
            && long.TryParse(expClaim.Value, NumberStyles.Integer, CultureInfo.InvariantCulture, out long exp)
            && exp > DateTimeOffset.UtcNow.ToUnixTimeSeconds();
    }

    /// <summary>
    /// Decodes the payload (second segment) of a JWT into one claim per top-level JSON property.
    /// The signature is not validated.
    /// </summary>
    /// <param name="token">The compact JWT (<c>header.payload.signature</c>).</param>
    /// <returns>A materialized list of claims; each value is the JSON text of the property (raw string for JSON strings).</returns>
    /// <exception cref="FormatException">The token has no payload segment or the payload is not valid base64url/JSON object.</exception>
    private static IEnumerable<Claim> GetClaimsFromToken(string token)
    {
        string[] segments = token.Split('.');
        if (segments.Length < 2)
        {
            throw new FormatException("Incorrect token");
        }

        // JWT segments are base64url without padding (RFC 7515): map to the standard alphabet and re-pad.
        string payload = segments[1].Replace('-', '+').Replace('_', '/');
        switch (payload.Length % 4)
        {
            case 2: payload += "=="; break;
            case 3: payload += "="; break;
        }

        byte[] jsonBytes = Convert.FromBase64String(payload);
        Dictionary<string, JsonElement>? keyValuePairs = JsonSerializer.Deserialize<Dictionary<string, JsonElement>>(jsonBytes);
        if (keyValuePairs is null)
        {
            throw new FormatException("Incorrect token");
        }

        // JsonElement.ToString() returns the raw string for JSON strings and the JSON text otherwise
        // ("" for JSON null, which previously caused a NullReferenceException).
        return keyValuePairs.Select(kvp => new Claim(kvp.Key, kvp.Value.ToString())).ToList();
    }

    /// <summary>
    /// Converts a Unix timestamp (seconds since 1970-01-01T00:00:00Z) into a <see cref="DateTime"/>
    /// whose <see cref="DateTime.Kind"/> is <see cref="DateTimeKind.Utc"/>.
    /// </summary>
    public static DateTime ConvertFromUnixTimestamp(int timestamp)
    {
        return DateTime.UnixEpoch.AddSeconds(timestamp);
    }

    #endregion
}