523 lines
21 KiB
C#
523 lines
21 KiB
C#
using Microsoft.Extensions.Logging;
|
|
using Microsoft.Extensions.Options;
|
|
using System.Net.Http.Headers;
|
|
using System.Text;
|
|
using System.Text.Json;
|
|
|
|
namespace MarketAlly.Replicate.Maui
{
    /// <summary>
    /// Client for the Replicate predictions API (<c>https://api.replicate.com/v1</c>).
    /// Creates image ("anime"), video, and preset-based predictions and, unless the caller
    /// opts out through <see cref="PredictionOptions"/>, polls until the prediction finishes.
    /// </summary>
    public class ReplicateTransformer : IReplicateTransformer
    {
        private const string BaseUrl = "https://api.replicate.com/v1";

        // Prefix applied to base64 payloads that do not already carry a "data:" URI scheme;
        // such payloads are assumed to be JPEG.
        private const string JpegDataUriPrefix = "data:image/jpeg;base64,";

        private readonly HttpClient _httpClient;
        private readonly ReplicateSettings _settings;
        private readonly ILogger<ReplicateTransformer>? _logger;
        private readonly JsonSerializerOptions _jsonOptions;

        /// <inheritdoc />
        public event EventHandler<PredictionCreatedEventArgs>? PredictionCreated;

        /// <summary>
        /// Creates a transformer that sends every request through <paramref name="httpClient"/>
        /// using the token, model versions, and timing values from <paramref name="settings"/>.
        /// </summary>
        /// <param name="httpClient">Client used for all API calls. This class never disposes it.</param>
        /// <param name="settings">Replicate configuration; <c>settings.Value</c> is read once here.</param>
        /// <param name="logger">Optional logger; when <c>null</c> nothing is logged.</param>
        /// <exception cref="ArgumentNullException"><paramref name="httpClient"/> or <paramref name="settings"/> is <c>null</c>.</exception>
        public ReplicateTransformer(
            HttpClient httpClient,
            IOptions<ReplicateSettings> settings,
            ILogger<ReplicateTransformer>? logger = null)
        {
            ArgumentNullException.ThrowIfNull(httpClient);
            ArgumentNullException.ThrowIfNull(settings);

            _httpClient = httpClient;
            _settings = settings.Value;
            _logger = logger;
            _jsonOptions = new JsonSerializerOptions
            {
                PropertyNamingPolicy = JsonNamingPolicy.SnakeCaseLower,
                WriteIndented = false
            };
        }

        #region Image ("anime") methods

        /// <summary>
        /// Base64-encodes <paramref name="imageBytes"/> and runs the configured image model on it.
        /// </summary>
        /// <param name="imageBytes">Raw image bytes; sent as a JPEG data URI.</param>
        /// <param name="customPrompt">Prompt override; <c>ImagePrompt</c> from settings is used when <c>null</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        /// <returns>The finished prediction, or the just-created one in webhook-only mode.</returns>
        public async Task<PredictionResult> TransformToAnimeAsync(
            byte[] imageBytes,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            var base64Image = Convert.ToBase64String(imageBytes);
            return await TransformToAnimeFromBase64Async(base64Image, customPrompt, options, cancellationToken);
        }

        /// <summary>
        /// Runs the configured image model on a base64 image. A payload that does not start
        /// with <c>data:</c> is wrapped as a JPEG data URI first.
        /// </summary>
        /// <param name="base64Image">Base64 payload or a complete data URI.</param>
        /// <param name="customPrompt">Prompt override; <c>ImagePrompt</c> from settings is used when <c>null</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        /// <returns>The finished prediction, or the just-created one in webhook-only mode.</returns>
        public async Task<PredictionResult> TransformToAnimeFromBase64Async(
            string base64Image,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            var input = BuildImageInput(EnsureDataUri(base64Image), customPrompt);

            return await CreatePredictionAsync(
                _settings.ModelVersion,
                input,
                options,
                _settings.TimeoutSeconds,
                _settings.PollingDelayMs,
                cancellationToken);
        }

        /// <summary>
        /// Runs the configured image model on an image the API fetches from <paramref name="imageUrl"/>.
        /// The URL is passed through unchanged.
        /// </summary>
        /// <param name="imageUrl">Publicly reachable image URL.</param>
        /// <param name="customPrompt">Prompt override; <c>ImagePrompt</c> from settings is used when <c>null</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        /// <returns>The finished prediction, or the just-created one in webhook-only mode.</returns>
        public async Task<PredictionResult> TransformToAnimeFromUrlAsync(
            string imageUrl,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            // Pass URL directly - Replicate accepts HTTP URLs for inputs
            var input = BuildImageInput(imageUrl, customPrompt);

            return await CreatePredictionAsync(
                _settings.ModelVersion,
                input,
                options,
                _settings.TimeoutSeconds,
                _settings.PollingDelayMs,
                cancellationToken);
        }

        #endregion

        #region Video methods

        /// <summary>
        /// Base64-encodes <paramref name="imageBytes"/> and runs the configured video model on it.
        /// </summary>
        /// <param name="imageBytes">Raw image bytes; sent as a JPEG data URI.</param>
        /// <param name="customPrompt">Prompt override; <c>VideoPrompt</c> from settings is used when <c>null</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        /// <returns>The finished prediction, or the just-created one in webhook-only mode.</returns>
        public async Task<PredictionResult> TransformToVideoAsync(
            byte[] imageBytes,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            var base64Image = Convert.ToBase64String(imageBytes);
            return await TransformToVideoFromBase64Async(base64Image, customPrompt, options, cancellationToken);
        }

        /// <summary>
        /// Runs the configured video model on a base64 image. A payload that does not start
        /// with <c>data:</c> is wrapped as a JPEG data URI first.
        /// </summary>
        /// <param name="base64Image">Base64 payload or a complete data URI.</param>
        /// <param name="customPrompt">Prompt override; <c>VideoPrompt</c> from settings is used when <c>null</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        /// <returns>The finished prediction, or the just-created one in webhook-only mode.</returns>
        public async Task<PredictionResult> TransformToVideoFromBase64Async(
            string base64Image,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            var input = BuildVideoInput(EnsureDataUri(base64Image), customPrompt);

            return await CreatePredictionAsync(
                _settings.VideoModelVersion,
                input,
                options,
                _settings.VideoTimeoutSeconds,
                _settings.VideoPollingDelayMs,
                cancellationToken);
        }

        /// <summary>
        /// Runs the configured video model on an image the API fetches from <paramref name="imageUrl"/>.
        /// The URL is passed through unchanged.
        /// </summary>
        /// <param name="imageUrl">Publicly reachable image URL.</param>
        /// <param name="customPrompt">Prompt override; <c>VideoPrompt</c> from settings is used when <c>null</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        /// <returns>The finished prediction, or the just-created one in webhook-only mode.</returns>
        public async Task<PredictionResult> TransformToVideoFromUrlAsync(
            string imageUrl,
            string? customPrompt = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            var input = BuildVideoInput(imageUrl, customPrompt);

            return await CreatePredictionAsync(
                _settings.VideoModelVersion,
                input,
                options,
                _settings.VideoTimeoutSeconds,
                _settings.VideoPollingDelayMs,
                cancellationToken);
        }

        #endregion

        #region Prediction management

        /// <summary>
        /// Asks the API to cancel the prediction with id <paramref name="predictionId"/>.
        /// </summary>
        /// <param name="predictionId">Id returned when the prediction was created.</param>
        /// <param name="cancellationToken">Cancels the HTTP call.</param>
        /// <exception cref="ArgumentException"><paramref name="predictionId"/> is null, empty, or whitespace.</exception>
        /// <exception cref="ReplicateApiException">The API answered with a non-success status code.</exception>
        public async Task CancelPredictionAsync(string predictionId, CancellationToken cancellationToken = default)
        {
            ArgumentException.ThrowIfNullOrWhiteSpace(predictionId);

            _logger?.LogInformation("Canceling prediction {PredictionId}", predictionId);

            using var request = CreateAuthorizedRequest(HttpMethod.Post, $"{PredictionUrl(predictionId)}/cancel");
            var (statusCode, isSuccess, responseBody) = await SendRequestAsync(request, cancellationToken);

            if (!isSuccess)
            {
                _logger?.LogError("Failed to cancel prediction: {StatusCode} - {Response}", statusCode, responseBody);
                throw new ReplicateApiException($"Failed to cancel prediction: {statusCode}", statusCode, responseBody);
            }

            _logger?.LogInformation("Prediction {PredictionId} canceled", predictionId);
        }

        /// <summary>
        /// Fetches the current state of the prediction with id <paramref name="predictionId"/>.
        /// </summary>
        /// <param name="predictionId">Id returned when the prediction was created.</param>
        /// <param name="cancellationToken">Cancels the HTTP call.</param>
        /// <returns>The parsed prediction (status, outputs, error, metrics, timestamps, cancel URL).</returns>
        /// <exception cref="ArgumentException"><paramref name="predictionId"/> is null, empty, or whitespace.</exception>
        /// <exception cref="ReplicateApiException">
        /// The API answered with a non-success status code, or with an authentication-error body.
        /// </exception>
        public async Task<PredictionResult> GetPredictionAsync(string predictionId, CancellationToken cancellationToken = default)
        {
            ArgumentException.ThrowIfNullOrWhiteSpace(predictionId);

            using var request = CreateAuthorizedRequest(HttpMethod.Get, PredictionUrl(predictionId));
            var (statusCode, isSuccess, responseBody) = await SendRequestAsync(request, cancellationToken);

            if (!isSuccess)
            {
                _logger?.LogError("Failed to get prediction: {StatusCode} - {Response}", statusCode, responseBody);
                throw new ReplicateApiException($"Failed to get prediction: {statusCode}", statusCode, responseBody);
            }

            using var document = JsonDocument.Parse(responseBody);
            var root = document.RootElement;

            // The API sometimes returns 200 with an auth error body. Only top-level error fields are
            // inspected, so a prompt/output/error that merely contains the word does not trip this.
            if (IsAuthenticationError(root))
            {
                throw new ReplicateApiException("Authentication failed - check your API token", System.Net.HttpStatusCode.Unauthorized, responseBody);
            }

            return ParsePrediction(root);
        }

        /// <summary>
        /// POSTs a prediction for <paramref name="version"/>, raises <see cref="PredictionCreated"/>,
        /// then either returns immediately (webhook-only or already completed) or polls to completion.
        /// </summary>
        /// <param name="version">Model version id sent as <c>version</c>.</param>
        /// <param name="input">Model input sent as <c>input</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="timeoutSeconds">Polling deadline in seconds.</param>
        /// <param name="pollingDelayMs">Delay before each status check, in milliseconds.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        /// <exception cref="ReplicateApiException">The create or status call failed at the HTTP/API level.</exception>
        /// <exception cref="ReplicateTransformationException">The prediction finished as failed or canceled.</exception>
        /// <exception cref="TimeoutException">Polling exceeded <paramref name="timeoutSeconds"/>.</exception>
        private async Task<PredictionResult> CreatePredictionAsync(
            string version,
            Dictionary<string, object> input,
            PredictionOptions? options,
            int timeoutSeconds,
            int pollingDelayMs,
            CancellationToken cancellationToken)
        {
            var payload = new Dictionary<string, object>
            {
                { "version", version },
                { "input", input }
            };

            // Add webhook configuration if provided
            if (!string.IsNullOrEmpty(options?.WebhookUrl))
            {
                payload["webhook"] = options.WebhookUrl;
                if (options.WebhookEventsFilter?.Length > 0)
                {
                    payload["webhook_events_filter"] = options.WebhookEventsFilter;
                }
            }

            try
            {
                _logger?.LogInformation("Creating prediction with version {Version}", version);

                using var request = CreateAuthorizedRequest(HttpMethod.Post, $"{BaseUrl}/predictions");
                request.Content = new StringContent(JsonSerializer.Serialize(payload, _jsonOptions), Encoding.UTF8, "application/json");

                // Sync mode: ask the API to hold the response open; the wait is clamped to 1..60 seconds.
                if (options?.SyncModeWaitSeconds is { } requestedWait)
                {
                    var waitSeconds = Math.Clamp(requestedWait, 1, 60);
                    request.Headers.Add("Prefer", $"wait={waitSeconds}");
                }

                var (statusCode, isSuccess, responseBody) = await SendRequestAsync(request, cancellationToken);

                if (!isSuccess)
                {
                    _logger?.LogError("Replicate API error: {StatusCode} - {Response}", statusCode, responseBody);
                    throw new ReplicateApiException($"Replicate API Error: {statusCode}", statusCode, responseBody);
                }

                var result = ParsePredictionResponse(responseBody);
                _logger?.LogInformation("Prediction created with ID: {PredictionId}, Status: {Status}", result.Id, result.Status);

                // Raise event immediately so tracking can start before polling
                PredictionCreated?.Invoke(this, new PredictionCreatedEventArgs(result));

                // Webhook-only callers receive the freshly created prediction without waiting.
                if (options?.WebhookOnly == true)
                {
                    return result;
                }

                // Sync mode may already have finished the prediction. Apply the same failed/canceled
                // handling as the polling path instead of returning a failed result as if it succeeded.
                if (result.IsCompleted)
                {
                    return ThrowIfFailedOrCanceled(result);
                }

                return await PollForCompletionAsync(result.Id, timeoutSeconds, pollingDelayMs, cancellationToken);
            }
            catch (Exception ex) when (ex is not ReplicateApiException and not OperationCanceledException)
            {
                // API errors are already logged where they are thrown, and cancellation is not an error.
                _logger?.LogError(ex, "Error creating prediction");
                throw;
            }
        }

        /// <summary>
        /// Waits <paramref name="pollingDelayMs"/> between status checks until the prediction completes.
        /// </summary>
        /// <exception cref="TimeoutException">
        /// The internal <paramref name="timeoutSeconds"/> deadline elapsed before completion.
        /// </exception>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was canceled.</exception>
        /// <exception cref="ReplicateTransformationException">The prediction finished as failed or canceled.</exception>
        private async Task<PredictionResult> PollForCompletionAsync(
            string predictionId,
            int timeoutSeconds,
            int pollingDelayMs,
            CancellationToken cancellationToken)
        {
            using var timeoutCts = new CancellationTokenSource(TimeSpan.FromSeconds(timeoutSeconds));
            using var linkedCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken, timeoutCts.Token);
            var token = linkedCts.Token;

            try
            {
                for (var attempt = 1; ; attempt++)
                {
                    await Task.Delay(pollingDelayMs, token);

                    var result = await GetPredictionAsync(predictionId, token);
                    _logger?.LogDebug("Status check {Attempt}: {Status}", attempt, result.Status);

                    if (result.IsCompleted)
                    {
                        return ThrowIfFailedOrCanceled(result);
                    }
                }
            }
            catch (OperationCanceledException ex) when (timeoutCts.IsCancellationRequested && !cancellationToken.IsCancellationRequested)
            {
                // Our own deadline fired (not the caller's token). Without this filter the timeout
                // would surface as OperationCanceledException from Task.Delay or the HTTP call.
                _logger?.LogWarning("Prediction {PredictionId} timed out after {TimeoutSeconds} seconds", predictionId, timeoutSeconds);
                throw new TimeoutException($"Prediction timed out after {timeoutSeconds} seconds", ex);
            }
        }

        /// <summary>
        /// Returns <paramref name="result"/> unchanged unless its status is <c>failed</c> or <c>canceled</c>.
        /// </summary>
        /// <exception cref="ReplicateTransformationException">The status is <c>failed</c> or <c>canceled</c>.</exception>
        private PredictionResult ThrowIfFailedOrCanceled(PredictionResult result)
        {
            if (result.Status == "failed")
            {
                _logger?.LogError("Prediction failed: {Error}", result.Error);
                throw new ReplicateTransformationException($"Prediction failed: {result.Error}");
            }

            if (result.Status == "canceled")
            {
                _logger?.LogWarning("Prediction was canceled");
                throw new ReplicateTransformationException("Prediction was canceled");
            }

            return result;
        }

        #endregion

        #region Input builders

        /// <summary>
        /// Prefixes <paramref name="base64Image"/> with a JPEG data-URI header unless it already
        /// starts with <c>data:</c> (ordinal, case-sensitive comparison).
        /// </summary>
        private static string EnsureDataUri(string base64Image)
        {
            ArgumentNullException.ThrowIfNull(base64Image);

            return base64Image.StartsWith("data:", StringComparison.Ordinal)
                ? base64Image
                : JpegDataUriPrefix + base64Image;
        }

        /// <summary>Builds the image-model input from <c>ImagePrompt</c> and <c>DefaultSettings</c>.</summary>
        private Dictionary<string, object> BuildImageInput(string image, string? customPrompt)
        {
            var prompt = customPrompt ?? _settings.ImagePrompt;
            var defaults = _settings.DefaultSettings;

            return new Dictionary<string, object>
            {
                { "image", image },
                { "prompt", prompt },
                { "seed", defaults.Seed },
                { "guidance_scale", defaults.GuidanceScale },
                { "strength", defaults.Strength },
                { "num_inference_steps", defaults.NumInferenceSteps }
            };
        }

        /// <summary>
        /// Builds the video-model input from <c>VideoPrompt</c> and <c>VideoSettings</c>. Seed, audio,
        /// and negative prompt are included only when configured.
        /// </summary>
        private Dictionary<string, object> BuildVideoInput(string image, string? customPrompt)
        {
            var prompt = customPrompt ?? _settings.VideoPrompt;
            var video = _settings.VideoSettings;

            var input = new Dictionary<string, object>
            {
                { "prompt", prompt },
                { "image", image },
                { "duration", video.Duration },
                { "size", video.Size },
                { "enable_prompt_expansion", video.EnablePromptExpansion }
            };

            if (video.Seed.HasValue)
            {
                input["seed"] = video.Seed.Value;
            }

            if (!string.IsNullOrEmpty(video.AudioUrl))
            {
                input["audio"] = video.AudioUrl;
            }

            if (!string.IsNullOrEmpty(video.NegativePrompt))
            {
                input["negative_prompt"] = video.NegativePrompt;
            }

            return input;
        }

        #endregion

        #region HTTP and JSON helpers

        /// <summary>Absolute URL of a single prediction; the id is escaped as one path segment.</summary>
        private static string PredictionUrl(string predictionId) =>
            $"{BaseUrl}/predictions/{Uri.EscapeDataString(predictionId)}";

        /// <summary>Creates a request carrying the configured <c>AuthScheme</c>/<c>ApiToken</c> authorization header.</summary>
        private HttpRequestMessage CreateAuthorizedRequest(HttpMethod method, string url)
        {
            var request = new HttpRequestMessage(method, url);
            request.Headers.Authorization = new AuthenticationHeaderValue(_settings.AuthScheme, _settings.ApiToken);
            return request;
        }

        /// <summary>
        /// Sends <paramref name="request"/>, reads the body as a string, and disposes the response.
        /// </summary>
        private async Task<(System.Net.HttpStatusCode StatusCode, bool IsSuccess, string Body)> SendRequestAsync(
            HttpRequestMessage request,
            CancellationToken cancellationToken)
        {
            using var response = await _httpClient.SendAsync(request, cancellationToken);
            var body = await response.Content.ReadAsStringAsync(cancellationToken);
            return (response.StatusCode, response.IsSuccessStatusCode, body);
        }

        /// <summary>
        /// True when the top-level JSON has a numeric <c>status</c> of 401, or a <c>title</c>/<c>detail</c>
        /// string containing "Unauthenticated".
        /// </summary>
        private static bool IsAuthenticationError(JsonElement root)
        {
            if (root.ValueKind != JsonValueKind.Object)
            {
                return false;
            }

            // A prediction's own "status" is a string, so only a numeric 401 counts here.
            if (root.TryGetProperty("status", out var status)
                && status.ValueKind == JsonValueKind.Number
                && status.TryGetInt32(out var code)
                && code == 401)
            {
                return true;
            }

            return FieldMentionsUnauthenticated(root, "title") || FieldMentionsUnauthenticated(root, "detail");

            static bool FieldMentionsUnauthenticated(JsonElement obj, string name) =>
                obj.TryGetProperty(name, out var value)
                && value.ValueKind == JsonValueKind.String
                && value.GetString()!.Contains("Unauthenticated", StringComparison.Ordinal);
        }

        /// <summary>Parses a prediction JSON string, disposing the pooled <see cref="JsonDocument"/>.</summary>
        private static PredictionResult ParsePredictionResponse(string responseBody)
        {
            using var document = JsonDocument.Parse(responseBody);
            return ParsePrediction(document.RootElement);
        }

        /// <summary>
        /// Maps a prediction JSON object onto <see cref="PredictionResult"/>. <c>id</c> and <c>status</c>
        /// are required; output, error, metrics, timestamps, and cancel URL are read only when present
        /// with the expected JSON kind.
        /// </summary>
        private static PredictionResult ParsePrediction(JsonElement root)
        {
            var result = new PredictionResult
            {
                Id = root.GetProperty("id").GetString() ?? string.Empty,
                Status = root.GetProperty("status").GetString() ?? string.Empty
            };

            // Output may be a single string or an array (only string items are kept).
            if (root.TryGetProperty("output", out var output))
            {
                switch (output.ValueKind)
                {
                    case JsonValueKind.Array:
                    {
                        var outputs = new List<string>();
                        foreach (var item in output.EnumerateArray())
                        {
                            if (item.ValueKind == JsonValueKind.String)
                            {
                                outputs.Add(item.GetString()!);
                            }
                        }

                        result.Outputs = outputs.ToArray();
                        result.Output = outputs.FirstOrDefault();
                        break;
                    }
                    case JsonValueKind.String:
                        result.Output = output.GetString();
                        result.Outputs = result.Output != null ? new[] { result.Output } : null;
                        break;
                }
            }

            if (root.TryGetProperty("error", out var error) && error.ValueKind == JsonValueKind.String)
            {
                result.Error = error.GetString();
            }

            if (root.TryGetProperty("metrics", out var metrics) && metrics.ValueKind == JsonValueKind.Object)
            {
                result.Metrics = new PredictionMetrics();

                // Non-numeric values (e.g. null) are skipped instead of making GetDouble throw.
                if (TryGetNumber(metrics, "predict_time", out var predictTime))
                {
                    result.Metrics.PredictTime = predictTime;
                }

                if (TryGetNumber(metrics, "total_time", out var totalTime))
                {
                    result.Metrics.TotalTime = totalTime;
                }
            }

            if (TryGetTimestamp(root, "created_at", out var createdAt))
            {
                result.CreatedAt = createdAt;
            }

            if (TryGetTimestamp(root, "started_at", out var startedAt))
            {
                result.StartedAt = startedAt;
            }

            if (TryGetTimestamp(root, "completed_at", out var completedAt))
            {
                result.CompletedAt = completedAt;
            }

            if (root.TryGetProperty("urls", out var urls)
                && urls.ValueKind == JsonValueKind.Object
                && urls.TryGetProperty("cancel", out var cancelUrl)
                && cancelUrl.ValueKind == JsonValueKind.String)
            {
                result.CancelUrl = cancelUrl.GetString();
            }

            return result;
        }

        /// <summary>Reads <paramref name="name"/> from <paramref name="obj"/> when it is a JSON number.</summary>
        private static bool TryGetNumber(JsonElement obj, string name, out double value)
        {
            value = default;
            return obj.TryGetProperty(name, out var element)
                && element.ValueKind == JsonValueKind.Number
                && element.TryGetDouble(out value);
        }

        /// <summary>
        /// Parses the string property <paramref name="name"/> as a timestamp using the invariant culture,
        /// so a non-Gregorian current culture cannot reinterpret API timestamps.
        /// </summary>
        private static bool TryGetTimestamp(JsonElement obj, string name, out DateTimeOffset value)
        {
            value = default;
            return obj.TryGetProperty(name, out var element)
                && element.ValueKind == JsonValueKind.String
                && DateTimeOffset.TryParse(
                    element.GetString(),
                    System.Globalization.CultureInfo.InvariantCulture,
                    System.Globalization.DateTimeStyles.None,
                    out value);
        }

        #endregion

        #region Preset-based Methods

        /// <summary>
        /// Base64-encodes <paramref name="imageBytes"/> and runs <paramref name="preset"/> on it.
        /// </summary>
        /// <param name="preset">Model preset supplying version, timing, and input mapping.</param>
        /// <param name="imageBytes">Raw image bytes; sent as a JPEG data URI.</param>
        /// <param name="customPrompt">Prompt passed to <c>preset.BuildInput</c>.</param>
        /// <param name="customParameters">Extra parameters passed to <c>preset.BuildInput</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        public async Task<PredictionResult> RunPresetAsync(
            ModelPreset preset,
            byte[] imageBytes,
            string? customPrompt = null,
            Dictionary<string, object>? customParameters = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            var base64Image = Convert.ToBase64String(imageBytes);
            return await RunPresetFromBase64Async(preset, base64Image, customPrompt, customParameters, options, cancellationToken);
        }

        /// <summary>
        /// Runs <paramref name="preset"/> on a base64 image. A payload that does not start with
        /// <c>data:</c> is wrapped as a JPEG data URI first.
        /// </summary>
        /// <param name="preset">Model preset supplying version, timing, and input mapping.</param>
        /// <param name="base64Image">Base64 payload or a complete data URI.</param>
        /// <param name="customPrompt">Prompt passed to <c>preset.BuildInput</c>.</param>
        /// <param name="customParameters">Extra parameters passed to <c>preset.BuildInput</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        public async Task<PredictionResult> RunPresetFromBase64Async(
            ModelPreset preset,
            string base64Image,
            string? customPrompt = null,
            Dictionary<string, object>? customParameters = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            ArgumentNullException.ThrowIfNull(preset);

            var input = preset.BuildInput(customPrompt, EnsureDataUri(base64Image), customParameters);

            return await CreatePredictionAsync(
                preset.ModelVersion,
                input,
                options,
                preset.TimeoutSeconds,
                preset.PollingDelayMs,
                cancellationToken);
        }

        /// <summary>
        /// Runs <paramref name="preset"/> on an image the API fetches from <paramref name="imageUrl"/>.
        /// </summary>
        /// <param name="preset">Model preset supplying version, timing, and input mapping.</param>
        /// <param name="imageUrl">Image URL passed to <c>preset.BuildInput</c> unchanged.</param>
        /// <param name="customPrompt">Prompt passed to <c>preset.BuildInput</c>.</param>
        /// <param name="customParameters">Extra parameters passed to <c>preset.BuildInput</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        public async Task<PredictionResult> RunPresetFromUrlAsync(
            ModelPreset preset,
            string imageUrl,
            string? customPrompt = null,
            Dictionary<string, object>? customParameters = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            ArgumentNullException.ThrowIfNull(preset);

            var input = preset.BuildInput(customPrompt, imageUrl, customParameters);

            return await CreatePredictionAsync(
                preset.ModelVersion,
                input,
                options,
                preset.TimeoutSeconds,
                preset.PollingDelayMs,
                cancellationToken);
        }

        /// <summary>
        /// Runs <paramref name="preset"/> with a prompt only (no image data is passed to <c>BuildInput</c>).
        /// </summary>
        /// <param name="preset">Model preset supplying version, timing, and input mapping.</param>
        /// <param name="prompt">Prompt passed to <c>preset.BuildInput</c>.</param>
        /// <param name="customParameters">Extra parameters passed to <c>preset.BuildInput</c>.</param>
        /// <param name="options">Webhook / sync-mode / webhook-only options, or <c>null</c>.</param>
        /// <param name="cancellationToken">Cancels the HTTP calls and polling.</param>
        public async Task<PredictionResult> RunPresetTextOnlyAsync(
            ModelPreset preset,
            string prompt,
            Dictionary<string, object>? customParameters = null,
            PredictionOptions? options = null,
            CancellationToken cancellationToken = default)
        {
            ArgumentNullException.ThrowIfNull(preset);

            var input = preset.BuildInput(prompt, imageData: null, customParameters);

            return await CreatePredictionAsync(
                preset.ModelVersion,
                input,
                options,
                preset.TimeoutSeconds,
                preset.PollingDelayMs,
                cancellationToken);
        }

        #endregion
    }
}
|