Files
2025-12-10 21:59:12 -05:00

523 lines
21 KiB
C#

using Microsoft.Extensions.Logging;
using Microsoft.Extensions.Options;
using System.Net.Http.Headers;
using System.Text;
using System.Text.Json;
namespace MarketAlly.Replicate.Maui
{
    /// <summary>
    /// Replicate API client (<c>https://api.replicate.com/v1</c>) that creates predictions for the
    /// configured image/video models or for <see cref="ModelPreset"/>s and, unless told otherwise via
    /// <see cref="PredictionOptions"/>, polls each prediction until it completes.
    /// </summary>
    public class ReplicateTransformer : IReplicateTransformer
    {
        private const string BaseUrl = "https://api.replicate.com/v1";

        // Header prepended to raw base64 input that is not already a "data:" URI; the media type is always declared as JPEG.
        private const string JpegDataUriPrefix = "data:image/jpeg;base64,";

        private readonly HttpClient _httpClient;
        private readonly ReplicateSettings _settings;
        private readonly ILogger<ReplicateTransformer>? _logger;
        private readonly JsonSerializerOptions _jsonOptions;

        /// <inheritdoc />
        public event EventHandler<PredictionCreatedEventArgs>? PredictionCreated;

        /// <summary>Creates a transformer that issues every Replicate call through <paramref name="httpClient"/>.</summary>
        /// <param name="httpClient">HTTP client used for all requests.</param>
        /// <param name="settings">Model versions, prompts, timeouts and credentials; <c>Value</c> is read once, here.</param>
        /// <param name="logger">Optional logger; when <c>null</c> nothing is logged.</param>
        /// <exception cref="ArgumentNullException"><paramref name="httpClient"/> or <paramref name="settings"/> is <c>null</c>.</exception>
        public ReplicateTransformer(
            HttpClient httpClient,
            IOptions<ReplicateSettings> settings,
            ILogger<ReplicateTransformer>? logger = null)
        {
            ArgumentNullException.ThrowIfNull(httpClient);
            ArgumentNullException.ThrowIfNull(settings);

            _httpClient = httpClient;
            _settings = settings.Value;
            _logger = logger;
            _jsonOptions = new JsonSerializerOptions
            {
                PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
                WriteIndented = false
            };
        }

        #region Image (anime) methods

        /// <summary>Runs the configured image model (<c>ModelVersion</c>) on raw image bytes.</summary>
        /// <param name="imageBytes">Encoded image; sent as a base64 data URI declared as <c>image/jpeg</c>.</param>
        /// <param name="customPrompt">Prompt override; <c>ImagePrompt</c> from settings is used when <c>null</c>.</param>
        /// <param name="options">Optional webhook / sync-mode / webhook-only behaviour.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and the polling loop.</param>
        /// <exception cref="ArgumentNullException"><paramref name="imageBytes"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="imageBytes"/> is empty.</exception>
        public async Task<PredictionResult> TransformToAnimeAsync(
            byte[] imageBytes,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            return await TransformToAnimeFromBase64Async(EncodeImage(imageBytes), customPrompt, options, cancellationToken);
        }

        /// <summary>
        /// Runs the configured image model on a base64 image. Input that does not start with <c>data:</c>
        /// is wrapped in a JPEG data URI; an existing data URI is sent unchanged.
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="base64Image"/> is null, empty or whitespace.</exception>
        public async Task<PredictionResult> TransformToAnimeFromBase64Async(
            string base64Image,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            var input = BuildImageInput(ToDataUri(base64Image), customPrompt);
            return await CreatePredictionAsync(
                _settings.ModelVersion,
                input,
                options,
                _settings.TimeoutSeconds,
                _settings.PollingDelayMs,
                cancellationToken);
        }

        /// <summary>Runs the configured image model on an image that Replicate downloads from <paramref name="imageUrl"/>.</summary>
        /// <exception cref="ArgumentException"><paramref name="imageUrl"/> is null, empty or whitespace.</exception>
        public async Task<PredictionResult> TransformToAnimeFromUrlAsync(
            string imageUrl,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            ArgumentException.ThrowIfNullOrWhiteSpace(imageUrl);

            // The URL is passed through as-is; Replicate accepts HTTP URLs for file inputs.
            var input = BuildImageInput(imageUrl, customPrompt);
            return await CreatePredictionAsync(
                _settings.ModelVersion,
                input,
                options,
                _settings.TimeoutSeconds,
                _settings.PollingDelayMs,
                cancellationToken);
        }

        #endregion

        #region Video methods

        /// <summary>Runs the configured video model (<c>VideoModelVersion</c>) on raw image bytes.</summary>
        /// <param name="imageBytes">Encoded image; sent as a base64 data URI declared as <c>image/jpeg</c>.</param>
        /// <param name="customPrompt">Prompt override; <c>VideoPrompt</c> from settings is used when <c>null</c>.</param>
        /// <param name="options">Optional webhook / sync-mode / webhook-only behaviour.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and the polling loop.</param>
        /// <exception cref="ArgumentNullException"><paramref name="imageBytes"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="imageBytes"/> is empty.</exception>
        public async Task<PredictionResult> TransformToVideoAsync(
            byte[] imageBytes,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            return await TransformToVideoFromBase64Async(EncodeImage(imageBytes), customPrompt, options, cancellationToken);
        }

        /// <summary>
        /// Runs the configured video model on a base64 image (wrapped in a JPEG data URI unless it already
        /// starts with <c>data:</c>).
        /// </summary>
        /// <exception cref="ArgumentException"><paramref name="base64Image"/> is null, empty or whitespace.</exception>
        public async Task<PredictionResult> TransformToVideoFromBase64Async(
            string base64Image,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            var input = BuildVideoInput(ToDataUri(base64Image), customPrompt);
            return await CreatePredictionAsync(
                _settings.VideoModelVersion,
                input,
                options,
                _settings.VideoTimeoutSeconds,
                _settings.VideoPollingDelayMs,
                cancellationToken);
        }

        /// <summary>Runs the configured video model on an image that Replicate downloads from <paramref name="imageUrl"/>.</summary>
        /// <exception cref="ArgumentException"><paramref name="imageUrl"/> is null, empty or whitespace.</exception>
        public async Task<PredictionResult> TransformToVideoFromUrlAsync(
            string imageUrl,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            ArgumentException.ThrowIfNullOrWhiteSpace(imageUrl);

            var input = BuildVideoInput(imageUrl, customPrompt);
            return await CreatePredictionAsync(
                _settings.VideoModelVersion,
                input,
                options,
                _settings.VideoTimeoutSeconds,
                _settings.VideoPollingDelayMs,
                cancellationToken);
        }

        #endregion

        #region Prediction management

        /// <summary>Sends <c>POST /predictions/{id}/cancel</c> for an existing prediction.</summary>
        /// <exception cref="ArgumentException"><paramref name="predictionId"/> is null, empty or whitespace.</exception>
        /// <exception cref="ReplicateApiException">The API returned a non-success status code.</exception>
        public async Task CancelPredictionAsync(string predictionId, CancellationToken cancellationToken = default)
        {
            ArgumentException.ThrowIfNullOrWhiteSpace(predictionId);
            _logger?.LogInformation("Canceling prediction {PredictionId}", predictionId);

            using var request = CreateAuthorizedRequest(HttpMethod.Post, $"{PredictionUrl(predictionId)}/cancel");
            using var response = await _httpClient.SendAsync(request, cancellationToken);
            var responseBody = await response.Content.ReadAsStringAsync(cancellationToken);

            if (!response.IsSuccessStatusCode)
            {
                _logger?.LogError("Failed to cancel prediction: {StatusCode} - {Response}", response.StatusCode, responseBody);
                throw new ReplicateApiException($"Failed to cancel prediction: {response.StatusCode}", response.StatusCode, responseBody);
            }

            _logger?.LogInformation("Prediction {PredictionId} canceled", predictionId);
        }

        /// <summary>Fetches the current state of a prediction via <c>GET /predictions/{id}</c>.</summary>
        /// <exception cref="ArgumentException"><paramref name="predictionId"/> is null, empty or whitespace.</exception>
        /// <exception cref="ReplicateApiException">
        /// The API returned a non-success status code, or a 2xx body that is an authentication error.
        /// </exception>
        public async Task<PredictionResult> GetPredictionAsync(string predictionId, CancellationToken cancellationToken = default)
        {
            ArgumentException.ThrowIfNullOrWhiteSpace(predictionId);

            using var request = CreateAuthorizedRequest(HttpMethod.Get, PredictionUrl(predictionId));
            using var response = await _httpClient.SendAsync(request, cancellationToken);
            var responseBody = await response.Content.ReadAsStringAsync(cancellationToken);

            if (!response.IsSuccessStatusCode)
            {
                _logger?.LogError("Failed to get prediction: {StatusCode} - {Response}", response.StatusCode, responseBody);
                throw new ReplicateApiException($"Failed to get prediction: {response.StatusCode}", response.StatusCode, responseBody);
            }

            JsonDocument document;
            try
            {
                document = JsonDocument.Parse(responseBody);
            }
            catch (JsonException) when (
                responseBody.Contains("\"status\":401", StringComparison.Ordinal)
                || responseBody.Contains("Unauthenticated", StringComparison.Ordinal))
            {
                // Non-JSON body that still looks like an auth failure: keep the text heuristic only for this case.
                throw CreateAuthenticationException(responseBody);
            }

            using (document)
            {
                var root = document.RootElement;

                // The API sometimes answers 200 with an error document. Inspect its top-level fields rather than
                // substring-matching the whole body, which also contains model "logs"/"output" text.
                if (IsAuthenticationError(root))
                {
                    throw CreateAuthenticationException(responseBody);
                }

                return ParsePrediction(root);
            }
        }

        #endregion

        #region Prediction creation and polling

        /// <summary>
        /// Creates a prediction (<c>POST /predictions</c>), raises <see cref="PredictionCreated"/>, and then:
        /// returns immediately when <c>options.WebhookOnly</c> is set; returns (or throws for failed/canceled)
        /// when the create response is already completed (sync mode); otherwise polls until completion.
        /// </summary>
        /// <exception cref="ReplicateApiException">The create or a poll request returned a non-success status.</exception>
        /// <exception cref="ReplicateTransformationException">The prediction finished as <c>failed</c> or <c>canceled</c>.</exception>
        /// <exception cref="TimeoutException">Polling exceeded <paramref name="timeoutSeconds"/>.</exception>
        private async Task<PredictionResult> CreatePredictionAsync(
            string version,
            Dictionary<string, object> input,
            PredictionOptions? options,
            int timeoutSeconds,
            int pollingDelayMs,
            CancellationToken cancellationToken)
        {
            var payload = new Dictionary<string, object>
            {
                { "version", version },
                { "input", input }
            };

            if (!string.IsNullOrEmpty(options?.WebhookUrl))
            {
                payload["webhook"] = options.WebhookUrl;
                if (options.WebhookEventsFilter?.Length > 0)
                {
                    payload["webhook_events_filter"] = options.WebhookEventsFilter;
                }
            }

            try
            {
                _logger?.LogInformation("Creating prediction with version {Version}", version);

                using var request = CreateAuthorizedRequest(HttpMethod.Post, $"{BaseUrl}/predictions");
                request.Content = new StringContent(JsonSerializer.Serialize(payload, _jsonOptions), Encoding.UTF8, "application/json");

                // Sync mode: ask the API to hold the response open (clamped to 1..60 s) until the prediction finishes.
                if (options?.SyncModeWaitSeconds.HasValue == true)
                {
                    var waitSeconds = Math.Clamp(options.SyncModeWaitSeconds.Value, 1, 60);
                    request.Headers.Add("Prefer", $"wait={waitSeconds}");
                }

                using var response = await _httpClient.SendAsync(request, cancellationToken);
                var responseBody = await response.Content.ReadAsStringAsync(cancellationToken);

                if (!response.IsSuccessStatusCode)
                {
                    _logger?.LogError("Replicate API error: {StatusCode} - {Response}", response.StatusCode, responseBody);
                    throw new ReplicateApiException($"Replicate API Error: {response.StatusCode}", response.StatusCode, responseBody);
                }

                var result = ParsePredictionResponse(responseBody);
                _logger?.LogInformation("Prediction created with ID: {PredictionId}, Status: {Status}", result.Id, result.Status);

                // Raised before polling so callers can start tracking the prediction id right away.
                PredictionCreated?.Invoke(this, new PredictionCreatedEventArgs(result));

                if (options?.WebhookOnly == true)
                {
                    return result;
                }

                if (result.IsCompleted)
                {
                    // Sync mode already finished: apply the same failed/canceled handling as the polling path.
                    return EnsureSucceeded(result);
                }

                return await PollForCompletionAsync(result.Id, timeoutSeconds, pollingDelayMs, cancellationToken);
            }
            catch (Exception ex) when (ex is not ReplicateApiException and not OperationCanceledException)
            {
                _logger?.LogError(ex, "Error creating prediction");
                throw;
            }
        }

        /// <summary>
        /// Polls <c>GET /predictions/{id}</c>, waiting <paramref name="pollingDelayMs"/> before every check
        /// (including the first), until the prediction completes.
        /// </summary>
        /// <exception cref="TimeoutException">The prediction did not complete within <paramref name="timeoutSeconds"/>.</exception>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled.</exception>
        /// <exception cref="ReplicateTransformationException">The prediction finished as <c>failed</c> or <c>canceled</c>.</exception>
        private async Task<PredictionResult> PollForCompletionAsync(
            string predictionId,
            int timeoutSeconds,
            int pollingDelayMs,
            CancellationToken cancellationToken)
        {
            using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(timeoutSeconds));
            using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token);
            var attempts = 0;

            try
            {
                while (true)
                {
                    // Throws as soon as either token is canceled; this is how the loop ends on timeout or cancellation.
                    await Task.Delay(pollingDelayMs, linkedCts.Token);
                    attempts++;

                    var result = await GetPredictionAsync(predictionId, linkedCts.Token);
                    _logger?.LogDebug("Status check {Attempt}: {Status}", attempts, result.Status);

                    if (result.IsCompleted)
                    {
                        return EnsureSucceeded(result);
                    }
                }
            }
            catch (OperationCanceledException ex) when (timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested)
            {
                // Only our own timeout fired: surface it as a timeout rather than as a caller cancellation.
                throw new TimeoutException($"Prediction timed out after {timeoutSeconds} seconds", ex);
            }
        }

        /// <summary>Returns a completed prediction unchanged, or throws when it ended as <c>failed</c> or <c>canceled</c>.</summary>
        private PredictionResult EnsureSucceeded(PredictionResult result)
        {
            switch (result.Status)
            {
                case "failed":
                    _logger?.LogError("Prediction failed: {Error}", result.Error);
                    throw new ReplicateTransformationException($"Prediction failed: {result.Error}");
                case "canceled":
                    _logger?.LogWarning("Prediction was canceled");
                    throw new ReplicateTransformationException("Prediction was canceled");
                default:
                    return result;
            }
        }

        #endregion

        #region Request helpers

        /// <summary>Builds a request carrying the configured <c>Authorization</c> header.</summary>
        private HttpRequestMessage CreateAuthorizedRequest(HttpMethod method, string url)
        {
            var request = new HttpRequestMessage(method, url);
            request.Headers.Authorization = new AuthenticationHeaderValue(_settings.AuthScheme, _settings.ApiToken);
            return request;
        }

        /// <summary>URL of a single prediction; the id is escaped so it cannot alter the request path.</summary>
        private static string PredictionUrl(string predictionId) =>
            $"{BaseUrl}/predictions/{Uri.EscapeDataString(predictionId)}";

        /// <summary>Validates raw image bytes and returns them base64-encoded.</summary>
        private static string EncodeImage(byte[] imageBytes)
        {
            ArgumentNullException.ThrowIfNull(imageBytes);
            if (imageBytes.Length == 0)
            {
                throw new ArgumentException("Image data must not be empty.", nameof(imageBytes));
            }

            return Convert.ToBase64String(imageBytes);
        }

        /// <summary>Returns <paramref name="base64Image"/> as a data URI, adding a JPEG header unless one is present.</summary>
        private static string ToDataUri(string base64Image)
        {
            ArgumentException.ThrowIfNullOrWhiteSpace(base64Image);
            return base64Image.StartsWith("data:", StringComparison.Ordinal)
                ? base64Image
                : JpegDataUriPrefix + base64Image;
        }

        /// <summary>Input payload for the image model, filled from <c>DefaultSettings</c>.</summary>
        private Dictionary<string, object> BuildImageInput(string image, string? customPrompt)
        {
            var defaults = _settings.DefaultSettings;
            return new Dictionary<string, object>
            {
                { "image", image },
                { "prompt", customPrompt ?? _settings.ImagePrompt },
                { "seed", defaults.Seed },
                { "guidance_scale", defaults.GuidanceScale },
                { "strength", defaults.Strength },
                { "num_inference_steps", defaults.NumInferenceSteps }
            };
        }

        /// <summary>
        /// Input payload for the video model, filled from <c>VideoSettings</c>; <c>seed</c>, <c>audio</c> and
        /// <c>negative_prompt</c> are only included when configured.
        /// </summary>
        private Dictionary<string, object> BuildVideoInput(string image, string? customPrompt)
        {
            var video = _settings.VideoSettings;
            var input = new Dictionary<string, object>
            {
                { "prompt", customPrompt ?? _settings.VideoPrompt },
                { "image", image },
                { "duration", video.Duration },
                { "size", video.Size },
                { "enable_prompt_expansion", video.EnablePromptExpansion }
            };

            if (video.Seed.HasValue)
            {
                input["seed"] = video.Seed.Value;
            }

            if (!string.IsNullOrEmpty(video.AudioUrl))
            {
                input["audio"] = video.AudioUrl;
            }

            if (!string.IsNullOrEmpty(video.NegativePrompt))
            {
                input["negative_prompt"] = video.NegativePrompt;
            }

            return input;
        }

        /// <summary>Exception thrown when a response body indicates an authentication failure.</summary>
        private static ReplicateApiException CreateAuthenticationException(string responseBody) =>
            new ReplicateApiException("Authentication failed - check your API token", System.Net.HttpStatusCode.Unauthorized, responseBody);

        /// <summary>
        /// True when the top-level JSON object is an auth error document: numeric <c>status</c> 401
        /// or <c>title</c> "Unauthenticated". A prediction's own <c>status</c> is a string, so it never matches.
        /// </summary>
        private static bool IsAuthenticationError(JsonElement root) =>
            root.ValueKind == JsonValueKind.Object
            && ((root.TryGetProperty("status", out var status)
                    && status.ValueKind == JsonValueKind.Number
                    && status.TryGetInt32(out var code)
                    && code == 401)
                || (root.TryGetProperty("title", out var title)
                    && title.ValueKind == JsonValueKind.String
                    && title.ValueEquals("Unauthenticated")));

        #endregion

        #region Response parsing

        /// <summary>Parses a prediction JSON document; the pooled <see cref="JsonDocument"/> is disposed before returning.</summary>
        private static PredictionResult ParsePredictionResponse(string responseBody)
        {
            using var document = JsonDocument.Parse(responseBody);
            return ParsePrediction(document.RootElement);
        }

        /// <summary>
        /// Maps a prediction object to <see cref="PredictionResult"/>. <c>id</c> and <c>status</c> are required;
        /// output (string or array of strings), error, metrics, timestamps and the cancel URL are optional.
        /// All values are copied out, so the result does not depend on the source document's lifetime.
        /// </summary>
        private static PredictionResult ParsePrediction(JsonElement root)
        {
            var result = new PredictionResult
            {
                Id = root.GetProperty("id").GetString() ?? string.Empty,
                Status = root.GetProperty("status").GetString() ?? string.Empty
            };

            if (root.TryGetProperty("output", out var output))
            {
                switch (output.ValueKind)
                {
                    case JsonValueKind.Array:
                    {
                        // Non-string array items are skipped.
                        var outputs = output.EnumerateArray()
                            .Where(item => item.ValueKind == JsonValueKind.String)
                            .Select(item => item.GetString()!)
                            .ToArray();
                        result.Outputs = outputs;
                        result.Output = outputs.FirstOrDefault();
                        break;
                    }
                    case JsonValueKind.String:
                        result.Output = output.GetString();
                        result.Outputs = result.Output != null ? new[] { result.Output } : null;
                        break;
                }
            }

            if (root.TryGetProperty("error", out var error) && error.ValueKind == JsonValueKind.String)
            {
                result.Error = error.GetString();
            }

            if (root.TryGetProperty("metrics", out var metrics) && metrics.ValueKind == JsonValueKind.Object)
            {
                var parsedMetrics = new PredictionMetrics();
                if (TryGetDouble(metrics, "predict_time", out var predictTime))
                {
                    parsedMetrics.PredictTime = predictTime;
                }

                if (TryGetDouble(metrics, "total_time", out var totalTime))
                {
                    parsedMetrics.TotalTime = totalTime;
                }

                result.Metrics = parsedMetrics;
            }

            if (TryGetTimestamp(root, "created_at", out var createdAt))
            {
                result.CreatedAt = createdAt;
            }

            if (TryGetTimestamp(root, "started_at", out var startedAt))
            {
                result.StartedAt = startedAt;
            }

            if (TryGetTimestamp(root, "completed_at", out var completedAt))
            {
                result.CompletedAt = completedAt;
            }

            if (root.TryGetProperty("urls", out var urls)
                && urls.ValueKind == JsonValueKind.Object
                && urls.TryGetProperty("cancel", out var cancelUrl)
                && cancelUrl.ValueKind == JsonValueKind.String)
            {
                result.CancelUrl = cancelUrl.GetString();
            }

            return result;
        }

        /// <summary>Reads a numeric property; JSON <c>null</c> or a missing property yields <c>false</c> instead of throwing.</summary>
        private static bool TryGetDouble(JsonElement obj, string propertyName, out double value)
        {
            value = 0;
            return obj.TryGetProperty(propertyName, out var element)
                && element.ValueKind == JsonValueKind.Number
                && element.TryGetDouble(out value);
        }

        /// <summary>
        /// Reads a timestamp string using invariant culture; a value without an explicit offset is treated as UTC.
        /// </summary>
        private static bool TryGetTimestamp(JsonElement obj, string propertyName, out DateTimeOffset value)
        {
            value = default;
            return obj.TryGetProperty(propertyName, out var element)
                && element.ValueKind == JsonValueKind.String
                && DateTimeOffset.TryParse(
                    element.GetString(),
                    System.Globalization.CultureInfo.InvariantCulture,
                    System.Globalization.DateTimeStyles.AssumeUniversal,
                    out value);
        }

        #endregion

        #region Preset-based Methods

        /// <summary>Runs <paramref name="preset"/> on raw image bytes (sent as a JPEG base64 data URI).</summary>
        /// <exception cref="ArgumentNullException"><paramref name="preset"/> or <paramref name="imageBytes"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="imageBytes"/> is empty.</exception>
        public async Task<PredictionResult> RunPresetAsync(
            ModelPreset preset,
            byte[] imageBytes,
            string? customPrompt = null,
            Dictionary<string, object>? customParameters = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            return await RunPresetFromBase64Async(preset, EncodeImage(imageBytes), customPrompt, customParameters, options, cancellationToken);
        }

        /// <summary>Runs <paramref name="preset"/> on a base64 image (wrapped in a JPEG data URI unless it already starts with <c>data:</c>).</summary>
        /// <exception cref="ArgumentNullException"><paramref name="preset"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="base64Image"/> is null, empty or whitespace.</exception>
        public async Task<PredictionResult> RunPresetFromBase64Async(
            ModelPreset preset,
            string base64Image,
            string? customPrompt = null,
            Dictionary<string, object>? customParameters = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            ArgumentNullException.ThrowIfNull(preset);
            return await RunPresetCoreAsync(preset, customPrompt, ToDataUri(base64Image), customParameters, options, cancellationToken);
        }

        /// <summary>Runs <paramref name="preset"/> on an image that Replicate downloads from <paramref name="imageUrl"/>.</summary>
        /// <exception cref="ArgumentNullException"><paramref name="preset"/> is <c>null</c>.</exception>
        /// <exception cref="ArgumentException"><paramref name="imageUrl"/> is null, empty or whitespace.</exception>
        public async Task<PredictionResult> RunPresetFromUrlAsync(
            ModelPreset preset,
            string imageUrl,
            string? customPrompt = null,
            Dictionary<string, object>? customParameters = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            ArgumentNullException.ThrowIfNull(preset);
            ArgumentException.ThrowIfNullOrWhiteSpace(imageUrl);
            return await RunPresetCoreAsync(preset, customPrompt, imageUrl, customParameters, options, cancellationToken);
        }

        /// <summary>Runs <paramref name="preset"/> with a prompt only (no image input).</summary>
        /// <exception cref="ArgumentNullException"><paramref name="preset"/> is <c>null</c>.</exception>
        public async Task<PredictionResult> RunPresetTextOnlyAsync(
            ModelPreset preset,
            string prompt,
            Dictionary<string, object>? customParameters = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            ArgumentNullException.ThrowIfNull(preset);
            var input = preset.BuildInput(prompt, imageData: null, customParameters);
            return await CreatePredictionAsync(
                preset.ModelVersion,
                input,
                options,
                preset.TimeoutSeconds,
                preset.PollingDelayMs,
                cancellationToken);
        }

        /// <summary>Builds the preset's input and runs it with the preset's model version, timeout and polling delay.</summary>
        private async Task<PredictionResult> RunPresetCoreAsync(
            ModelPreset preset,
            string? customPrompt,
            string imageData,
            Dictionary<string, object>? customParameters,
            PredictionOptions? options,
            CancellationToken cancellationToken)
        {
            var input = preset.BuildInput(customPrompt, imageData, customParameters);
            return await CreatePredictionAsync(
                preset.ModelVersion,
                input,
                options,
                preset.TimeoutSeconds,
                preset.PollingDelayMs,
                cancellationToken);
        }

        #endregion
    }
}