// Source: sqrtspace-dotnet/tests/SqrtSpace.SpaceTime.Tests/AspNetCore/CheckpointMiddlewareTests.cs
// (web-view export: 491 lines, 16 KiB, C#; snapshot dated 2025-07-20 03:41:39 -04:00)
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Text;
using System.Text.Json;
using System.Threading.Tasks;
using FluentAssertions;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Hosting;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Hosting;
using SqrtSpace.SpaceTime.AspNetCore;
using SqrtSpace.SpaceTime.Core;
using Xunit;
namespace SqrtSpace.SpaceTime.Tests.AspNetCore;
/// <summary>
/// Integration tests for the SpaceTime checkpoint and streaming middleware, hosted
/// in-process on a <see cref="TestServer"/>. Each test instance gets its own temporary
/// checkpoint directory, which <see cref="Dispose"/> removes on a best-effort basis.
/// </summary>
public class CheckpointMiddlewareTests : IDisposable
{
    // The endpoints serialize with ASP.NET Core's web defaults (camelCase property names).
    // Responses must be read back with the same defaults: the plain
    // JsonSerializer.Deserialize overload is case-sensitive and would leave every
    // property of the target type at its default value.
    private static readonly JsonSerializerOptions s_jsonOptions = new(JsonSerializerDefaults.Web);

    private readonly TestServer _server;
    private readonly HttpClient _client;
    private readonly string _checkpointDirectory;

    public CheckpointMiddlewareTests()
    {
        _checkpointDirectory = Path.Combine(Path.GetTempPath(), "spacetime_middleware_tests", Guid.NewGuid().ToString());
        Directory.CreateDirectory(_checkpointDirectory);

        var builder = new WebHostBuilder()
            .ConfigureServices(services =>
            {
                services.AddSpaceTime(options =>
                {
                    options.EnableCheckpointing = true;
                    options.CheckpointDirectory = _checkpointDirectory;
                    options.CheckpointStrategy = CheckpointStrategy.Linear;
                    options.CheckpointInterval = TimeSpan.FromSeconds(5);
                });
                services.AddControllers();
            })
            .Configure(app =>
            {
                app.UseSpaceTime();
                app.UseRouting();
                app.UseEndpoints(endpoints =>
                {
                    endpoints.MapControllers();
                    endpoints.MapPost("/process", ProcessRequestAsync);
                    endpoints.MapPost("/process-with-checkpoint", ProcessWithCheckpointAsync);
                    endpoints.MapGet("/stream", StreamDataAsync);
                });
            });

        _server = new TestServer(builder);
        _client = _server.CreateClient();
    }

    /// <summary>
    /// Disposes the client and server, then deletes the temporary checkpoint directory.
    /// Deletion is best-effort. A request that outlived its test (e.g. the cancelled
    /// long-running request) may still be writing checkpoint files, and Dispose must not throw.
    /// </summary>
    public void Dispose()
    {
        _client.Dispose();
        _server.Dispose();

        try
        {
            if (Directory.Exists(_checkpointDirectory))
            {
                Directory.Delete(_checkpointDirectory, recursive: true);
            }
        }
        catch (Exception ex) when (ex is IOException or UnauthorizedAccessException)
        {
            // Leave the temp directory behind rather than failing the test run during teardown.
        }
    }

    [Fact]
    public async Task CheckpointMiddleware_AddsCheckpointFeature()
    {
        // Act
        using var response = await _client.PostAsync("/process", new StringContent("test"));

        // Assert
        response.StatusCode.Should().Be(HttpStatusCode.OK);
        response.Headers.Should().ContainKey("X-Checkpoint-Enabled");
        response.Headers.GetValues("X-Checkpoint-Enabled").First().Should().Be("true");
    }

    [Fact]
    public async Task EnableCheckpointAttribute_EnablesCheckpointing()
    {
        // Arrange
        var content = JsonSerializer.Serialize(new { items = Enumerable.Range(1, 20).ToList() });

        // Act
        using var response = await _client.PostAsync("/api/checkpoint/process",
            new StringContent(content, Encoding.UTF8, "application/json"));

        // Assert
        response.StatusCode.Should().Be(HttpStatusCode.OK);
        var result = await response.Content.ReadAsStringAsync();
        result.Should().Contain("processed");
        result.Should().Contain("20");

        // Verify checkpoint was created
        var checkpointFiles = Directory.GetFiles(_checkpointDirectory, "*.json");
        checkpointFiles.Should().NotBeEmpty();
    }

    [Fact]
    public async Task CheckpointRecovery_ResumesFromCheckpoint()
    {
        // Arrange - First request that will fail
        var checkpointId = Guid.NewGuid().ToString();
        using var request1 = new HttpRequestMessage(HttpMethod.Post, "/api/checkpoint/process-with-failure")
        {
            Headers = { { "X-Checkpoint-Id", checkpointId } },
            Content = new StringContent(
                JsonSerializer.Serialize(new { items = Enumerable.Range(1, 20).ToList(), failAt = 10 }),
                Encoding.UTF8,
                "application/json")
        };

        // Act - First request should fail
        using var response1 = await _client.SendAsync(request1);
        response1.StatusCode.Should().Be(HttpStatusCode.InternalServerError);

        // Act - Resume with same checkpoint ID
        using var request2 = new HttpRequestMessage(HttpMethod.Post, "/api/checkpoint/process-with-failure")
        {
            Headers = { { "X-Checkpoint-Id", checkpointId } },
            Content = new StringContent(
                JsonSerializer.Serialize(new { items = Enumerable.Range(1, 20).ToList() }),
                Encoding.UTF8,
                "application/json")
        };
        using var response2 = await _client.SendAsync(request2);

        // Assert
        response2.StatusCode.Should().Be(HttpStatusCode.OK);
        var result = await response2.Content.ReadAsStringAsync();
        var processResult = JsonSerializer.Deserialize<ProcessResult>(result, s_jsonOptions);
        processResult.Should().NotBeNull();
        processResult!.ProcessedCount.Should().Be(20);
        processResult.ResumedFromCheckpoint.Should().BeTrue();
        processResult.StartedFrom.Should().BeGreaterThan(0);
    }

    [Fact]
    public async Task StreamingMiddleware_ChunksLargeResponses()
    {
        // Act
        using var response = await _client.GetAsync("/stream?count=1000");

        // Assert
        response.StatusCode.Should().Be(HttpStatusCode.OK);
        response.Headers.TransferEncodingChunked.Should().BeTrue();
        var content = await response.Content.ReadAsStringAsync();
        var items = JsonSerializer.Deserialize<List<StreamItem>>(content, s_jsonOptions);
        items.Should().HaveCount(1000);
    }

    [Fact]
    public async Task SpaceTimeStreamingAttribute_EnablesChunking()
    {
        // Act
        var stream = await _client.GetStreamAsync("/api/streaming/large-dataset?count=100");

        // Read streamed content.
        // NOTE(review): assumes [SpaceTimeStreaming] emits one JSON object per line — confirm.
        var items = new List<DataItem>();
        using var reader = new StreamReader(stream);
        string? line;
        while ((line = await reader.ReadLineAsync()) != null)
        {
            if (string.IsNullOrWhiteSpace(line))
            {
                continue;
            }

            var item = JsonSerializer.Deserialize<DataItem>(line, s_jsonOptions);
            if (item is not null)
            {
                items.Add(item);
            }
        }

        // Assert
        items.Should().HaveCount(100);
        items.Select(i => i.Id).Should().BeEquivalentTo(Enumerable.Range(1, 100));
    }

    [Fact]
    public async Task Middleware_TracksMemoryUsage()
    {
        // Act - the [ApiController]/[FromBody] action only accepts a JSON media type,
        // so the content type must be explicit (StringContent defaults to text/plain).
        using var response = await _client.PostAsync("/api/memory/intensive",
            new StringContent(JsonSerializer.Serialize(new { size = 1000 }), Encoding.UTF8, "application/json"));

        // Assert
        response.StatusCode.Should().Be(HttpStatusCode.OK);
        response.Headers.Should().ContainKey("X-Memory-Before");
        response.Headers.Should().ContainKey("X-Memory-After");
        response.Headers.Should().ContainKey("X-Memory-Peak");

        var memoryBefore = long.Parse(response.Headers.GetValues("X-Memory-Before").First());
        var memoryPeak = long.Parse(response.Headers.GetValues("X-Memory-Peak").First());
        memoryPeak.Should().BeGreaterThan(memoryBefore);
    }

    [Fact]
    public async Task ConcurrentRequests_HandleCheckpointingCorrectly()
    {
        // Arrange
        var tasks = new List<Task<HttpResponseMessage>>();

        // Act
        for (var i = 0; i < 5; i++)
        {
            var checkpointId = $"concurrent_{i}";
            var request = new HttpRequestMessage(HttpMethod.Post, "/api/checkpoint/process")
            {
                Headers = { { "X-Checkpoint-Id", checkpointId } },
                Content = new StringContent(
                    JsonSerializer.Serialize(new { items = Enumerable.Range(1, 10).ToList() }),
                    Encoding.UTF8,
                    "application/json")
            };
            tasks.Add(_client.SendAsync(request));
        }

        var responses = await Task.WhenAll(tasks);

        try
        {
            // Assert
            responses.Should().AllSatisfy(r => r.StatusCode.Should().Be(HttpStatusCode.OK));

            // Each request should have created its own checkpoint
            var checkpointFiles = Directory.GetFiles(_checkpointDirectory, "concurrent_*.json");
            checkpointFiles.Should().HaveCount(5);
        }
        finally
        {
            foreach (var response in responses)
            {
                response.Dispose();
            }
        }
    }

    [Fact]
    public async Task RequestTimeout_CheckpointsBeforeTimeout()
    {
        // Arrange
        var checkpointId = Guid.NewGuid().ToString();
        using var request = new HttpRequestMessage(HttpMethod.Post, "/api/checkpoint/long-running")
        {
            Headers = { { "X-Checkpoint-Id", checkpointId } },
            Content = new StringContent(
                JsonSerializer.Serialize(new { duration = 10000 }), // 10 seconds
                Encoding.UTF8,
                "application/json")
        };

        // Act - Cancel after 2 seconds
        using var cts = new System.Threading.CancellationTokenSource(TimeSpan.FromSeconds(2));
        try
        {
            // Only the server-side checkpoint matters; a completed response is simply released.
            var response = await _client.SendAsync(request, cts.Token);
            response.Dispose();
        }
        catch (OperationCanceledException)
        {
            // Expected: the client gives up before the 10-second request finishes.
        }

        // Assert - Checkpoint should exist even though request was cancelled
        await Task.Delay(500); // Give time for checkpoint to be written
        var checkpointFile = Path.Combine(_checkpointDirectory, $"{checkpointId}.json");
        File.Exists(checkpointFile).Should().BeTrue();
    }

    /// <summary>
    /// Minimal endpoint that reports, via the <c>X-Checkpoint-Enabled</c> response header,
    /// whether an <see cref="ICheckpointFeature"/> was attached to the request.
    /// </summary>
    private static async Task ProcessRequestAsync(HttpContext context)
    {
        var checkpoint = context.Features.Get<ICheckpointFeature>();

        // Indexer assignment rather than Headers.Add: Add throws if the key already exists (ASP0019).
        context.Response.Headers["X-Checkpoint-Enabled"] = checkpoint is not null ? "true" : "false";
        await context.Response.WriteAsync("Processed");
    }

    /// <summary>
    /// Simulates 20 units of work, asking the checkpoint manager after each one whether a
    /// checkpoint is due, and writes <c>{ processed }</c> as JSON.
    /// </summary>
    /// <exception cref="InvalidOperationException">No <see cref="ICheckpointFeature"/> on the request.</exception>
    private static async Task ProcessWithCheckpointAsync(HttpContext context)
    {
        var checkpoint = context.Features.Get<ICheckpointFeature>()
            ?? throw new InvalidOperationException("ICheckpointFeature was not found on the request.");

        var processed = 0;
        for (var i = 1; i <= 20; i++)
        {
            processed = i;
            if (checkpoint.CheckpointManager.ShouldCheckpoint())
            {
                await checkpoint.CheckpointManager.CreateCheckpointAsync(new { processed = i });
            }

            await Task.Delay(10); // Simulate work
        }

        await context.Response.WriteAsJsonAsync(new { processed });
    }

    /// <summary>
    /// Writes a JSON array of <c>count</c> <see cref="StreamItem"/>s (query parameter, default 100).
    /// </summary>
    private static async Task StreamDataAsync(HttpContext context)
    {
        // Fall back to 100 when "count" is missing, non-numeric or negative instead of failing
        // the request with a FormatException / ArgumentOutOfRangeException.
        var count = int.TryParse(context.Request.Query["count"].FirstOrDefault(), out var parsed) && parsed >= 0
            ? parsed
            : 100;
        var items = Enumerable.Range(1, count).Select(i => new StreamItem { Id = i, Value = $"Item {i}" });

        // WriteAsJsonAsync sets the application/json Content-Type itself.
        await context.Response.WriteAsJsonAsync(items);
    }

    private class StreamItem
    {
        public int Id { get; set; }
        public string Value { get; set; } = "";
    }

    private class DataItem
    {
        public int Id { get; set; }
        public string Name { get; set; } = "";
        public DateTime Timestamp { get; set; }
    }

    private class ProcessResult
    {
        public int ProcessedCount { get; set; }
        public bool ResumedFromCheckpoint { get; set; }
        public int StartedFrom { get; set; }
    }
}
// Test controllers
/// <summary>
/// Test controller whose actions are decorated with <c>[EnableCheckpoint]</c>. It exercises
/// checkpoint creation, recovery after a simulated failure, and long-running work.
/// </summary>
[ApiController]
[Route("api/checkpoint")]
public class CheckpointTestController : ControllerBase
{
    /// <summary>
    /// Processes every item (10 ms each), checkpointing whenever the manager says one is due.
    /// Returns <c>{ processed }</c>.
    /// </summary>
    [HttpPost("process")]
    [EnableCheckpoint]
    public async Task<IActionResult> ProcessItems([FromBody] ProcessRequest request)
    {
        var checkpoint = GetCheckpointFeature();
        var processedCount = 0;

        foreach (var item in request.Items)
        {
            await Task.Delay(10); // Simulate processing
            processedCount++;

            if (checkpoint.CheckpointManager.ShouldCheckpoint())
            {
                await checkpoint.CheckpointManager.CreateCheckpointAsync(new { processedCount, lastItem = item });
            }
        }

        return Ok(new { processed = processedCount });
    }

    /// <summary>
    /// Resumes from the latest <see cref="ProcessState"/> checkpoint, if any, then processes
    /// the remaining items. It throws when the item index reaches <c>FailAt</c>, which
    /// surfaces as a 500 response.
    /// </summary>
    [HttpPost("process-with-failure")]
    [EnableCheckpoint]
    public async Task<IActionResult> ProcessWithFailure([FromBody] ProcessWithFailureRequest request)
    {
        var checkpoint = GetCheckpointFeature();

        // Try to load previous state
        var state = await checkpoint.CheckpointManager.RestoreLatestCheckpointAsync<ProcessState>();
        var startFrom = state?.ProcessedCount ?? 0;
        var processedCount = startFrom;

        for (var i = startFrom; i < request.Items.Count; i++)
        {
            // Lifted comparison: false when FailAt is null.
            if (request.FailAt == i)
            {
                throw new InvalidOperationException("Simulated failure");
            }

            processedCount++;

            if (checkpoint.CheckpointManager.ShouldCheckpoint())
            {
                await checkpoint.CheckpointManager.CreateCheckpointAsync(new ProcessState { ProcessedCount = processedCount });
            }
        }

        return Ok(new ProcessResult
        {
            ProcessedCount = processedCount,
            ResumedFromCheckpoint = startFrom > 0,
            StartedFrom = startFrom
        });
    }

    /// <summary>
    /// Runs for roughly <c>Duration</c> milliseconds in 100 ms steps, checkpointing the step
    /// counter whenever the manager says one is due. Returns <c>{ completed }</c>.
    /// </summary>
    [HttpPost("long-running")]
    [EnableCheckpoint(Strategy = CheckpointStrategy.Linear)]
    public async Task<IActionResult> LongRunning([FromBody] LongRunningRequest request)
    {
        var checkpoint = GetCheckpointFeature();
        var progress = 0;

        for (var i = 0; i < request.Duration / 100; i++)
        {
            await Task.Delay(100);
            progress++;

            if (checkpoint.CheckpointManager.ShouldCheckpoint())
            {
                await checkpoint.CheckpointManager.CreateCheckpointAsync(new { progress });
            }
        }

        return Ok(new { completed = progress });
    }

    // Fetches the per-request checkpoint feature, failing with a descriptive error instead of
    // a NullReferenceException when it was not attached to the request.
    private ICheckpointFeature GetCheckpointFeature() =>
        HttpContext.Features.Get<ICheckpointFeature>()
        ?? throw new InvalidOperationException("ICheckpointFeature was not found on the request.");

    public class ProcessRequest
    {
        public List<int> Items { get; set; } = new();
    }

    public class ProcessWithFailureRequest : ProcessRequest
    {
        public int? FailAt { get; set; }
    }

    public class LongRunningRequest
    {
        public int Duration { get; set; }
    }

    private class ProcessState
    {
        public int ProcessedCount { get; set; }
    }

    private class ProcessResult
    {
        public int ProcessedCount { get; set; }
        public bool ResumedFromCheckpoint { get; set; }
        public int StartedFrom { get; set; }
    }
}
/// <summary>
/// Test controller that exposes an async stream of <see cref="DataItem"/>s marked with
/// <c>[SpaceTimeStreaming]</c>, so the streaming/chunking behavior can be observed.
/// </summary>
[ApiController]
[Route("api/streaming")]
public class StreamingTestController : ControllerBase
{
    /// <summary>
    /// Yields items numbered 1..<paramref name="count"/>, pausing 1 ms after each one to
    /// imitate a slow data source.
    /// </summary>
    /// <param name="count">Number of items to produce (query string, default 100).</param>
    [HttpGet("large-dataset")]
    [SpaceTimeStreaming(ChunkStrategy = ChunkStrategy.SqrtN)]
    public async IAsyncEnumerable<DataItem> GetLargeDataset([FromQuery] int count = 100)
    {
        var id = 1;
        while (id <= count)
        {
            var next = new DataItem
            {
                Id = id,
                Name = $"Item {id}",
                Timestamp = DateTime.UtcNow
            };
            yield return next;

            await Task.Delay(1); // Simulate data retrieval
            id++;
        }
    }

    /// <summary>Payload row produced by <see cref="GetLargeDataset"/>.</summary>
    public class DataItem
    {
        public int Id { get; set; }
        public string Name { get; set; } = "";
        public DateTime Timestamp { get; set; }
    }
}
/// <summary>
/// Test controller that allocates a caller-specified amount of memory. Memory-tracking
/// middleware tests use it to observe a rise in memory usage.
/// </summary>
[ApiController]
[Route("api/memory")]
public class MemoryTestController : ControllerBase
{
    // Largest Size (in KB) whose byte count still fits in an Int32 array length.
    private const int MaxSizeKb = int.MaxValue / 1024;

    /// <summary>
    /// Allocates <c>Size</c> KB of random bytes, forces a full GC, and returns
    /// <c>{ allocated }</c> with the byte count.
    /// Returns 400 when <c>Size</c> is negative or too large to express in bytes.
    /// </summary>
    [HttpPost("intensive")]
    public IActionResult MemoryIntensive([FromBody] MemoryRequest request)
    {
        // Unchecked int multiplication would wrap silently: some large sizes would allocate
        // the wrong amount, and others would produce a negative array length.
        if (request.Size < 0 || request.Size > MaxSizeKb)
        {
            return BadRequest(new { error = $"Size must be between 0 and {MaxSizeKb} KB." });
        }

        // Allocate some memory
        var data = new byte[request.Size * 1024]; // Size in KB
        Random.Shared.NextBytes(data);

        // Force GC to get accurate memory readings (deliberate in this test-only controller).
        GC.Collect();
        GC.WaitForPendingFinalizers();
        GC.Collect();

        return Ok(new { allocated = data.Length });
    }

    public class MemoryRequest
    {
        public int Size { get; set; }
    }
}