369 lines
13 KiB
Python
369 lines
13 KiB
Python
"""
|
||
异步视频处理和数字人合成API模块
|
||
支持文件上传、异步任务处理、模板管理等功能
|
||
"""
|
||
|
||
import asyncio
|
||
import uuid as uuid_lib
|
||
from typing import Optional, Dict, Any, List
|
||
import os
|
||
import json
|
||
import time
|
||
from enum import Enum
|
||
from dataclasses import dataclass
|
||
import threading
|
||
import queue
|
||
from datetime import datetime
|
||
import logging
|
||
|
||
logger = logging.getLogger(__name__)


def _default_search_dirs(file_type: str) -> List[str]:
    """Return the built-in lookup directories for ``file_type``, in priority order.

    Unknown file types have no default directories (an empty list).
    """
    if file_type == 'audio':
        return [
            "/mnt/docker/resource/uploads/audio",
            os.path.expanduser('~/heygem_data/voice/data'),
            "/mnt/docker/resource",
            "/mnt/docker/code/data/temp",
        ]
    if file_type == 'video':
        return [
            "/mnt/docker/resource/uploads/video",
            os.path.expanduser('~/heygem_data/face2face/temp'),
            "/mnt/docker/resource",
            "/mnt/docker/video_resource",
        ]
    if file_type == 'image':
        return [
            "/mnt/docker/resource/uploads/image",
            "/mnt/docker/resource",
        ]
    return []


def resolve_file_path(filename: str, file_type: str,
                      search_dirs: Optional[List[str]] = None) -> str:
    """
    Resolve a (possibly bare) file name to an existing file, supporting
    directories mounted into the Docker container.

    Args:
        filename: File name, optionally including a directory component.
        file_type: File category: 'audio', 'video' or 'image'. Any other value
            has no default search directories.
        search_dirs: Optional directories to search instead of the built-in
            defaults for ``file_type``; tried in order.

    Returns:
        ``filename`` itself when it is an absolute path to an existing file;
        otherwise the first ``<dir>/<basename(filename)>`` that is an existing
        regular file; otherwise ``filename`` unchanged, so the caller can
        report the error.
    """
    # An absolute path to an existing file needs no lookup.
    if os.path.isabs(filename) and os.path.isfile(filename):
        return filename

    # Only the bare name is matched against the search directories.
    base_filename = os.path.basename(filename)
    if search_dirs is None:
        search_dirs = _default_search_dirs(file_type)

    for directory in search_dirs:
        candidate = os.path.join(directory, base_filename)
        # isfile (not exists): an empty base name would otherwise "match" the
        # search directory itself and hand a directory back to the caller.
        if os.path.isfile(candidate):
            logger.info("Resolved %s file: %s -> %s", file_type, filename, candidate)
            return candidate

    # Nothing found: return the original name and let the caller handle it.
    logger.warning("Could not resolve %s file path: %s", file_type, filename)
    return filename
|
||
|
||
class TaskStatus(Enum):
    """Lifecycle state of a task; each value is the lowercase member name."""

    # Created and queued, not yet picked up by the worker.
    PENDING = "pending"
    # Being run by the worker.
    PROCESSING = "processing"
    # Finished successfully.
    COMPLETED = "completed"
    # Finished with an error.
    FAILED = "failed"
    # Cancelled (no code in this module assigns this state).
    CANCELLED = "cancelled"
|
||
|
||
class TaskType(Enum):
    """Kind of work a task performs; the value is the lowercase member name."""

    # Text-to-speech, optionally training a voice from reference audio first.
    VOICE_GENERATION = "voice_generation"
    # Full digital-human generation from a sample video and voice.
    DIGITAL_HUMAN_CREATION = "digital_human_creation"
    # Combine a stored template video with an audio file.
    VIDEO_COMPOSITION = "video_composition"
    # Render and register a new titled template video.
    TEMPLATE_CREATION = "template_creation"
    # Speech-to-text over a video file.
    AUDIO_EXTRACTION = "audio_extraction"
|
||
|
||
@dataclass
class Task:
    """State of one asynchronous task tracked by the task manager."""

    task_id: str
    task_type: TaskType
    status: TaskStatus
    # Progress indicator; the task manager in this module reports 0-100.
    progress: int = 0
    # Handler output, set when the task completes.
    result: Optional[Dict[str, Any]] = None
    # Failure reason, set when the task fails.
    error_message: Optional[str] = None
    # Defaults to construction time when not supplied.
    created_at: Optional[datetime] = None
    # Defaults to created_at when not supplied.
    updated_at: Optional[datetime] = None
    user_id: Optional[str] = None
    # Handler-specific parameters.
    input_data: Optional[Dict[str, Any]] = None

    def __post_init__(self):
        """Fill in missing timestamps without overwriting supplied ones."""
        if self.created_at is None:
            self.created_at = datetime.now()
        # Only default updated_at: an explicitly passed value must be kept,
        # and a supplied created_at must still yield a non-None updated_at.
        if self.updated_at is None:
            self.updated_at = self.created_at
|
||
|
||
class TaskManager:
    """In-memory task registry served by a single background worker thread.

    ``create_task`` stores a Task in ``self.tasks`` and pushes its id onto a
    FIFO queue. A daemon thread pops ids and runs the handler matching the
    task's type, recording progress, result and errors on the Task object.
    Task state lives only in memory; template metadata is persisted as JSON
    files under ``TEMPLATE_DIR``.
    """

    # Directory holding one ``<template_id>.json`` metadata file per template.
    TEMPLATE_DIR = "/mnt/docker/resource/templates"

    def __init__(self):
        self.tasks: Dict[str, Task] = {}
        self.task_queue = queue.Queue()
        self.worker_thread = None
        self.is_running = False

    def start_worker(self):
        """Start the background worker thread unless it is already running."""
        if self.is_running:
            return
        self.is_running = True
        self.worker_thread = threading.Thread(target=self._worker_loop, daemon=True)
        self.worker_thread.start()
        logger.info("Task worker started")

    def stop_worker(self):
        """Ask the worker loop to exit and wait for the thread to finish.

        The loop polls the queue with a 1 s timeout. This call therefore returns
        about a second after the in-flight task (if any) completes.
        """
        self.is_running = False
        if self.worker_thread:
            self.worker_thread.join()
        logger.info("Task worker stopped")

    def create_task(self, task_type: TaskType, input_data: Dict[str, Any],
                    user_id: Optional[str] = None) -> str:
        """Register a new PENDING task, enqueue it for the worker and return its id."""
        task_id = str(uuid_lib.uuid4())
        task = Task(
            task_id=task_id,
            task_type=task_type,
            status=TaskStatus.PENDING,
            input_data=input_data,
            user_id=user_id,
        )
        # Store before enqueueing so the worker can always find the task.
        self.tasks[task_id] = task
        self.task_queue.put(task_id)
        logger.info("Created task %s of type %s", task_id, task_type.value)
        return task_id

    def get_task(self, task_id: str) -> Optional[Task]:
        """Return the task with ``task_id``, or None if it is unknown."""
        return self.tasks.get(task_id)

    def update_task_status(self, task_id: str, status: TaskStatus,
                           progress: Optional[int] = None,
                           result: Optional[Dict[str, Any]] = None,
                           error_message: Optional[str] = None):
        """Set a task's status and, when given, its progress, result and error.

        Unknown task ids are logged and ignored.
        """
        task = self.tasks.get(task_id)
        if task is None:
            logger.warning("Cannot update unknown task %s", task_id)
            return
        # Write the payload fields BEFORE the status. Another thread polling
        # get_task() must never see COMPLETED/FAILED while result or
        # error_message is still unset.
        if progress is not None:
            task.progress = progress
        if result is not None:
            task.result = result
        if error_message is not None:
            task.error_message = error_message
        task.updated_at = datetime.now()
        task.status = status
        logger.info("Updated task %s: status=%s, progress=%s", task_id, status.value, progress)

    def _worker_loop(self):
        """Worker thread body: process queued task ids until stop_worker() is called."""
        while self.is_running:
            try:
                # The timeout lets the loop re-check is_running periodically.
                task_id = self.task_queue.get(timeout=1)
            except queue.Empty:
                continue
            try:
                self._process_task(task_id)
            except Exception as e:
                logger.exception("Error in worker loop: %s", e)
            finally:
                # Always balance get(); otherwise Queue.join() would block forever.
                self.task_queue.task_done()

    def _process_task(self, task_id: str):
        """Run the handler for one task and record COMPLETED or FAILED on it."""
        task = self.get_task(task_id)
        if not task:
            logger.error("Task %s not found", task_id)
            return
        if task.status == TaskStatus.CANCELLED:
            # Cancelled while still queued: do not start it or overwrite its state.
            logger.info("Skipping cancelled task %s", task_id)
            return

        handlers = {
            TaskType.VOICE_GENERATION: self._process_voice_generation,
            TaskType.DIGITAL_HUMAN_CREATION: self._process_digital_human_creation,
            TaskType.VIDEO_COMPOSITION: self._process_video_composition,
            TaskType.TEMPLATE_CREATION: self._process_template_creation,
            TaskType.AUDIO_EXTRACTION: self._process_audio_extraction,
        }

        try:
            self.update_task_status(task_id, TaskStatus.PROCESSING, 10)
            handler = handlers.get(task.task_type)
            if handler is None:
                raise ValueError(f"Unknown task type: {task.task_type}")
            result = handler(task)
            self.update_task_status(task_id, TaskStatus.COMPLETED, 100, result)
        except Exception as e:
            # Handler failures become FAILED tasks; the traceback goes to the log.
            logger.exception("Task %s failed: %s", task_id, e)
            self.update_task_status(task_id, TaskStatus.FAILED, error_message=str(e))

    def _process_voice_generation(self, task: Task) -> Dict[str, Any]:
        """Generate speech for ``input_data['text']``.

        Optional input keys: ``reference_audio``, ``reference_text``, and
        ``uuid`` (defaults to the first 8 characters of the task id). When
        reference audio is given without a transcript, the voice is trained
        first via ``digital_human_api.train_voice_v2``. Its returned
        'asr_format_audio_url' and 'reference_audio_text' are then used.
        """
        import digital_human_api

        input_data = task.input_data
        text = input_data['text']
        reference_audio = input_data.get('reference_audio')
        reference_text = input_data.get('reference_text', '')
        uuid = input_data.get('uuid', task.task_id[:8])

        self.update_task_status(task.task_id, TaskStatus.PROCESSING, 30)

        if reference_audio and not reference_text:
            voice_result = digital_human_api.train_voice_v2(reference_audio)
            # NOTE(review): if the training result lacks these keys, reference_audio
            # becomes None; confirm train_voice_v2 always returns them.
            reference_audio = voice_result.get('asr_format_audio_url')
            reference_text = voice_result.get('reference_audio_text', '')
            self.update_task_status(task.task_id, TaskStatus.PROCESSING, 60)

        audio_path = digital_human_api.generate_voice_v2(text, reference_audio, reference_text, uuid)

        return {
            "audio_path": audio_path,
            "audio_url": f"/download/generated/audio/{uuid}",
            "uuid": uuid,
            "text": text,
        }

    def _process_digital_human_creation(self, task: Task) -> Dict[str, Any]:
        """Create a digital-human video.

        Required input keys are ``speech_text``, ``sample_video`` and
        ``sample_voice``. The sample file names are passed through unchanged.
        """
        import digital_human_api

        input_data = task.input_data
        speech_text = input_data['speech_text']
        sample_video = input_data['sample_video']
        sample_voice = input_data['sample_voice']
        uuid = input_data.get('uuid', task.task_id[:8])

        self.update_task_status(task.task_id, TaskStatus.PROCESSING, 20)

        result = digital_human_api.generate_digital_human_v2(
            speech_text,
            sample_video,
            sample_voice,
            uuid,
        )

        self.update_task_status(task.task_id, TaskStatus.PROCESSING, 80)

        return {
            "digital_human_result": result,
            "video_url": f"/download/generated/video/{uuid}",
            "audio_url": f"/download/generated/audio/{uuid}",
            "uuid": uuid,
            "speech_text": speech_text,
        }

    def _process_video_composition(self, task: Task) -> Dict[str, Any]:
        """Compose a video from a stored template and an audio file.

        Raises:
            ValueError: if the template does not exist or records no video path.
        """
        import api

        input_data = task.input_data
        template_id = input_data['template_id']
        audio_file = input_data['audio_file']
        uuid = input_data.get('uuid', task.task_id[:8])

        self.update_task_status(task.task_id, TaskStatus.PROCESSING, 40)

        template_info = self._get_template_info(template_id)
        if not template_info:
            raise ValueError(f"Template {template_id} not found")

        # _save_template_info records the video under "template_path". Accept the
        # "video_path" key as well, which this reader originally required.
        video_path = template_info.get('video_path') or template_info.get('template_path')
        if not video_path:
            raise ValueError(f"Template {template_id} has no video path")

        result = api.generate_video(video_path, audio_file, uuid)

        self.update_task_status(task.task_id, TaskStatus.PROCESSING, 90)

        return {
            "composition_result": result,
            "video_url": f"/download/generated/video/{uuid}",
            "template_id": template_id,
            "uuid": uuid,
        }

    def _process_template_creation(self, task: Task) -> Dict[str, Any]:
        """Render a titled template video and register its metadata.

        Required input keys are ``person_image`` and ``background_image``.
        Optional keys (with their defaults) are ``title_text`` (''),
        ``title_position`` ([50, 50]), ``title_font_size`` (48) and
        ``video_length`` (10.0).
        """
        import api

        input_data = task.input_data
        person_image = input_data['person_image']
        title_text = input_data.get('title_text', '')
        title_position = tuple(input_data.get('title_position', [50, 50]))
        title_font_size = input_data.get('title_font_size', 48)
        background_image = input_data['background_image']
        video_length = input_data.get('video_length', 10.0)

        self.update_task_status(task.task_id, TaskStatus.PROCESSING, 50)

        template_path = api.generate_video_with_title(
            person_image, title_text, title_position,
            title_font_size, background_image, video_length,
        )

        template_id = self._save_template_info({
            "template_path": template_path,
            # Also stored under "video_path" for readers that expect that key.
            "video_path": template_path,
            "title_text": title_text,
            "person_image": person_image,
            "background_image": background_image,
            "created_at": datetime.now().isoformat(),
        })

        return {
            "template_id": template_id,
            "template_path": template_path,
            "template_url": f"/download/template/{template_id}",
            "title_text": title_text,
        }

    def _process_audio_extraction(self, task: Task) -> Dict[str, Any]:
        """Transcribe the speech in ``input_data['video_file']``."""
        import api

        video_file = task.input_data['video_file']

        self.update_task_status(task.task_id, TaskStatus.PROCESSING, 50)

        extracted_text = api.speech_to_text(video_file)

        return {
            "extracted_text": extracted_text,
            "video_file": video_file,
        }

    def _get_template_info(self, template_id: str) -> Optional[Dict[str, Any]]:
        """Load a template's metadata, or return None if it does not exist.

        ``template_id`` comes from task input. Ids that are not a plain file name
        (empty, or containing a path separator) are rejected, so lookups cannot
        escape TEMPLATE_DIR.
        """
        name = str(template_id)
        if not name or os.path.basename(name) != name:
            logger.warning("Rejected invalid template id: %r", template_id)
            return None
        template_file = os.path.join(self.TEMPLATE_DIR, f"{name}.json")
        try:
            with open(template_file, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def _save_template_info(self, template_data: Dict[str, Any]) -> str:
        """Persist template metadata under a fresh UUID and return that id."""
        template_id = str(uuid_lib.uuid4())
        os.makedirs(self.TEMPLATE_DIR, exist_ok=True)

        template_file = os.path.join(self.TEMPLATE_DIR, f"{template_id}.json")
        with open(template_file, 'w', encoding='utf-8') as f:
            json.dump(template_data, f, ensure_ascii=False, indent=2)

        return template_id
|
||
|
||
# Process-wide task manager shared by callers of this module. The worker thread
# is started as an import-time side effect, so tasks submitted through
# task_manager.create_task() are picked up without further setup.
task_manager = TaskManager()
task_manager.start_worker()
|