Files
digital_human_backend/flask_api.py
2025-09-05 00:40:39 +08:00

308 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from flask import Flask, request, jsonify, render_template, abort, send_from_directory, Response
import logging
import os
from moviepy import VideoFileClip
import uuid
from typing import Optional
from flask_cors import CORS
from typing import Dict, Any
import threading
# 导入API模块中的所有函数和类
from api import (
APIException,
# train_voice,
# generate_voice,
# generate_video,
get_video_generate_process,
generate_digital_human,
# download_video,
# download_audio,
download_generated_video,
# download_generated_audio,
list_available_files,
find_generated_files,
Config
)
# Import the file-upload module; fall back to a stub when it is unavailable.
try:
    from file_upload import FileManager
    file_manager_class = FileManager()
    # Expose the same object under both names used in this module.
    file_manager = file_manager_class
except ImportError:
    class MockFileManager:
        """Stand-in used when the ``file_upload`` module cannot be imported."""

        def copy_audio_for_tts(self, voice_file_name, temp_uuid):
            """Simulate copying a voice sample and return the resulting file name."""
            return f"tts_copy_{voice_file_name}"

        def save_file(self, file, file_type, custom_name=None):
            """Reject uploads with an explicit error, since no real storage backend exists.

            Raises:
                NotImplementedError: always; the upload routes turn this into a 500 response.
            """
            raise NotImplementedError(
                "file_upload module is not available; uploaded files cannot be saved"
            )

    file_manager = MockFileManager()
    # The upload routes reference ``file_manager_class``; bind it in the fallback
    # branch as well so they do not fail with a NameError.
    file_manager_class = file_manager
    logging.warning("Mock file_upload.file_manager is used. Actual file operations might fail.")
# In-memory registry of background tasks, keyed by task id.
# Each value is a metadata dict with the keys written by create_task():
# "status", "type", "params", "result" and "progress".
TASKS: Dict[str, Dict[str, Any]] = {}
def create_task(task_type: str, params: Dict[str, Any]) -> str:
    """Register a new task in ``TASKS`` and return its generated id.

    The record starts with status ``"pending"``, progress ``0`` and no result.

    Args:
        task_type: Label stored under the ``"type"`` key.
        params: Parameters stored as-is under the ``"params"`` key.

    Returns:
        The new task id, a random UUID4 rendered as a string.
    """
    new_id = str(uuid.uuid4())
    record: Dict[str, Any] = dict(
        status="pending",
        type=task_type,
        params=params,
        result=None,
        progress=0,
    )
    TASKS[new_id] = record
    return new_id
def update_task(task_id: str, status: str, progress: int = None, result: Dict[str, Any] = None):
    """Update the stored metadata of an existing task.

    Unknown task ids are ignored. ``progress`` and ``result`` are only
    overwritten when a non-``None`` value is supplied.

    Args:
        task_id: Id previously returned by ``create_task``.
        status: New value for the ``"status"`` key.
        progress: Optional new value for the ``"progress"`` key.
        result: Optional new value for the ``"result"`` key.
    """
    try:
        record = TASKS[task_id]
    except KeyError:
        return

    record["status"] = status
    for field, value in (("progress", progress), ("result", result)):
        if value is not None:
            record[field] = value
def get_task(task_id):
    """Return the metadata dict stored for ``task_id``, or ``None`` if it is unknown."""
    return TASKS.get(task_id)
def generate_digital_human(speech_text, sample_video, sample_voice, gen_uuid):
    """Placeholder generator: wait five seconds, then report success.

    NOTE(review): this definition rebinds the ``generate_digital_human`` name
    imported from ``api`` at the top of the file, so later code in this module
    calls this stub rather than the imported function. Confirm this is intended.

    Args:
        speech_text: Unused by the stub.
        sample_video: Unused by the stub.
        sample_voice: Unused by the stub.
        gen_uuid: Identifier echoed back in the result.

    Returns:
        A dict with ``"code"`` set to 200 and ``"uuid"`` set to ``gen_uuid``.
    """
    from time import sleep

    sleep(5)  # simulate heavy work
    return dict(code=200, uuid=gen_uuid)
def run_extract_audio(task_id, video_path, audio_uuid):
    """Extract the audio track of a video into a WAV file and record the outcome.

    Used as a background-thread target by the extract-audio endpoint. Any
    failure is caught, logged, and stored on the task as ``{"error": ...}``
    with status ``"failed"``; nothing is raised to the caller.

    Args:
        task_id: Id of the task whose status/progress/result is updated.
        video_path: Path of the source video on the server's filesystem.
        audio_uuid: Base name (without extension) for the generated WAV file.
    """
    try:
        update_task(task_id, "running", 10)

        # NOTE(review): video_path is taken directly from the request body by the
        # endpoint; consider restricting it to known upload directories.
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")

        # Build the output path with os.path.join so it is correct whether or not
        # RESOURCE_DIR ends with a separator, and make sure the directory exists.
        out_dir = os.path.join(Config.RESOURCE_DIR, 'uploads', 'audio')
        os.makedirs(out_dir, exist_ok=True)
        audio_path = os.path.join(out_dir, f"{audio_uuid}.wav")

        clip = VideoFileClip(video_path)
        try:
            if clip.audio is None:
                # Give a clear message instead of an AttributeError on None.
                raise ValueError(f"Video has no audio track: {video_path}")
            clip.audio.write_audiofile(audio_path, codec="pcm_s16le")
        finally:
            # Release the reader even when extraction fails.
            clip.close()

        # NOTE(review): no route for /download/generated/audio/ is defined in the
        # visible part of this file; confirm it exists elsewhere.
        update_task(task_id, "completed", 100, {
            "uuid": audio_uuid,
            "download_url": f"/download/generated/audio/{audio_uuid}.wav"
        })
    except Exception as e:
        logging.getLogger(__name__).exception("Audio extraction failed for task %s", task_id)
        update_task(task_id, "failed", 100, {"error": str(e)})
# === Background runner ===
def run_generate_task(task_id, data, gen_uuid):
    """Run digital-human generation and store the outcome on the task.

    Used as a background-thread target. Any exception is caught and stored
    on the task as ``{"error": ...}`` with status ``"failed"``.

    Args:
        task_id: Id of the task to update.
        data: Mapping providing ``"speech_text"``, ``"sample_video"`` and
            ``"sample_voice"``.
        gen_uuid: Identifier passed to the generator and used in the download URL.
    """
    try:
        update_task(task_id, "running", 10)

        speech_text = data["speech_text"]
        sample_video = data["sample_video"]
        sample_voice = data["sample_voice"]
        outcome = generate_digital_human(speech_text, sample_video, sample_voice, gen_uuid)

        # NOTE(review): the ``task_id`` query parameter is filled with the
        # generator's "code" value, not with this task's id -- confirm intent.
        video_url = f"/download/generated/video/{gen_uuid}?task_id={outcome.get('code')}"
        summary = {"uuid": gen_uuid, "download_url": video_url}
        update_task(task_id, "completed", 100, summary)
    except Exception as err:
        update_task(task_id, "failed", 100, {"error": str(err)})
# --- Flask app setup ---
# Create the application object that all routes below are registered on.
app = Flask(__name__)
# Load settings from the object named by the import string 'api.Config'.
app.config.from_object('api.Config')
# Configure root logging at INFO level and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Enable CORS for the app using flask_cors with no extra options.
CORS(app)
# --- Error handlers ---
@app.errorhandler(APIException)
def handle_api_exception(e):
    """Turn an ``APIException`` into a JSON error body with the exception's status code."""
    body = jsonify({"error": e.message})
    return body, e.status_code
@app.errorhandler(404)
def handle_not_found(e):
    """Answer 404 errors with a JSON error body instead of the default page."""
    body = jsonify({"error": "Resource not found"})
    return body, 404
@app.errorhandler(500)
def handle_server_error(e):
    """Answer 500 errors with a generic JSON error body."""
    body = jsonify({"error": "Internal server error"})
    return body, 500
@app.route("/api/video/extract_audio", methods=["POST"])
def api_extract_audio():
    """Submit an asynchronous audio-extraction task for a video file.

    Expects a JSON object with a ``video_path`` field. Responds with the new
    task id and the URL to poll for status, or with a 400 JSON error when the
    body is missing, is not valid JSON, or lacks ``video_path``.
    """
    # silent=True yields None for a missing/invalid JSON body or wrong content
    # type instead of raising, so the client always gets the JSON 400 below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "video_path" not in data:
        return jsonify({"error": "Missing field 'video_path'"}), 400

    # NOTE(review): video_path is client-supplied and is opened on the server by
    # the background worker; consider validating it against allowed directories.
    video_path = data["video_path"]
    audio_uuid = str(uuid.uuid4())

    # Create async task
    task_id = create_task("extract_audio", {"video_path": video_path})
    thread = threading.Thread(target=run_extract_audio, args=(task_id, video_path, audio_uuid))
    thread.start()

    return jsonify({
        "status": "submitted",
        "task_id": task_id,
        "query_url": f"/api/task/status/{task_id}"
    })
# --- Example: Digital Human ---
@app.route("/api/digital_human/generate", methods=["POST"])
def api_generate_digital_human():
    """Submit an asynchronous digital-human generation task.

    Expects a JSON object with ``speech_text``, ``sample_video`` and
    ``sample_voice``; an optional ``uuid`` field selects the generation id.
    Responds with the new task id and the URL to poll for status, or with a
    400 JSON error when the body is missing, invalid, or incomplete.
    """
    # silent=True yields None for a missing/invalid JSON body or wrong content
    # type instead of raising, so the client always gets the JSON 400 below.
    data = request.get_json(silent=True)
    required_fields = ["speech_text", "sample_video", "sample_voice"]
    if not isinstance(data, dict) or not all(f in data for f in required_fields):
        return jsonify({
            "error": f"Missing one of the required fields: {', '.join(required_fields)}"
        }), 400

    # Create task
    task_id = create_task("digital_human", data)
    # Fall back to a fresh UUID when "uuid" is absent, null or empty.
    gen_uuid = data.get("uuid") or str(uuid.uuid4())

    # Run async in background thread
    thread = threading.Thread(target=run_generate_task, args=(task_id, data, gen_uuid))
    thread.start()

    return jsonify({
        "status": "submitted",
        "task_id": task_id,
        "query_url": f"/api/task/status/{task_id}"
    })
# ---
## File operation API
def _is_within_directory(path, root):
    """Return True when ``path`` is ``root`` itself or lies beneath it (symlinks resolved)."""
    real_path = os.path.realpath(path)
    real_root = os.path.realpath(root)
    try:
        return os.path.commonpath([real_path, real_root]) == real_root
    except ValueError:
        # commonpath raises when the paths cannot be compared (e.g. different drives).
        return False


@app.route('/api/files/list', methods=['GET'])
def api_list_files():
    """
    List available files.

    Optional query parameters:
    - directory: directory path (defaults to the configured RESOURCE_DIR)
    - file_type: file type ("video", "audio", "all")

    The requested directory must be inside one of the configured working
    directories; anything else is rejected with a 403 JSON error so the
    endpoint cannot be used to list arbitrary server paths.
    """
    directory = request.args.get('directory', app.config['RESOURCE_DIR'])
    file_type = request.args.get('file_type', 'all')

    # Directories this service is configured to work in.
    allowed_roots = [
        os.path.expanduser(app.config[key])
        for key in ('RESOURCE_DIR', 'TEMP_DIR', 'VOICE_DATA_DIR', 'FACE2FACE_TEMP_DIR')
        if app.config.get(key)
    ]
    if not any(_is_within_directory(directory, root) for root in allowed_roots):
        return jsonify({"error": "Directory not allowed"}), 403

    result = list_available_files(directory, file_type)
    return jsonify(result)
@app.route('/api/files/find/<uuid>', methods=['GET'])
def api_find_generated_files(uuid: str):
    """
    Look up the generated files for the given UUID.

    Responds with the file list, or with a 404 JSON message when nothing is
    found. Note that the ``uuid`` parameter hides the ``uuid`` module inside
    this function.
    """
    matches = find_generated_files(uuid)
    if matches:
        return jsonify({"status": "success", "files": matches})
    return jsonify({"message": "No generated files found for this UUID"}), 404
# Root directory for uploads (hard-coded absolute path).
UPLOAD_FOLDER = "/mnt/docker/resource"
# Ensure the upload folder exists; create it if it doesn't.
# This runs at import time, so importing this module touches the filesystem.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
# Set the upload folder for the Flask app
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
# Define allowed video extensions for security
# NOTE(review): in the visible part of this file only allowed_file() reads this
# set, and the upload routes do not call allowed_file() -- confirm validation
# happens inside the file manager.
ALLOWED_EXTENSIONS = {'mp4', 'avi', 'mov', 'mkv', 'flv', 'webm'}
# --- Helper Function for File Type Validation ---
def allowed_file(filename):
    """
    Report whether ``filename`` has an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared case-insensitively.
    Names without a dot are never allowed.
    """
    _, dot, extension = filename.rpartition('.')
    return dot == '.' and extension.lower() in ALLOWED_EXTENSIONS
# --- Video Upload API Endpoint ---
@app.route('/api/upload/video', methods=['POST'])
def upload_video():
    """Upload a video file.

    Reads the multipart part named ``file`` and the optional form field
    ``custom_name`` and hands them to the file manager's ``save_file``.
    Responds with 400 when no file part is present and with 500 when saving
    raises.
    """
    try:
        uploads = request.files
        if 'file' not in uploads:
            return jsonify({"error": "没有文件"}), 400

        saved_info = file_manager_class.save_file(
            uploads['file'], 'video', request.form.get('custom_name')
        )
        payload = {
            "success": True,
            "message": "视频上传成功",
            "file_info": saved_info,
        }
        return jsonify(payload)
    except Exception as exc:
        logger.error(f"Video upload failed: {str(exc)}")
        return jsonify({"error": f"上传失败: {str(exc)}"}), 500
@app.route('/api/upload/audio', methods=['POST'])
def upload_audio():
    """Upload an audio file.

    Reads the multipart part named ``file`` and the optional form field
    ``custom_name`` and hands them to the file manager's ``save_file``.
    Responds with 400 when no file part is present and with 500 when saving
    raises.
    """
    try:
        if 'file' not in request.files:
            return jsonify({"error": "没有文件"}), 400

        file = request.files['file']
        custom_name = request.form.get('custom_name')

        file_info = file_manager_class.save_file(file, 'audio', custom_name)

        return jsonify({
            "success": True,
            "message": "音频上传成功",
            "file_info": file_info
        })
    except Exception as e:
        # Log as an audio failure so the entry is not confused with video uploads.
        logger.error(f"Audio upload failed: {str(e)}")
        return jsonify({"error": f"上传失败: {str(e)}"}), 500
@app.route("/api/task/status/<task_id>", methods=["GET"])
def api_task_status(task_id):
    """Return the stored metadata of a task as JSON, or a 404 JSON error if it is unknown."""
    record = get_task(task_id)
    if record:
        return jsonify(record)
    return jsonify({"error": "Task not found"}), 404
@app.route('/download/generated/video/<uuid>', methods=['GET'])
def download_generated_video_route(uuid: str):
    """
    Download the generated video file for the given UUID.

    Optional query parameters:
    - task_id: task id, forwarded unchanged to ``download_generated_video``
    """
    return download_generated_video(uuid, request.args.get('task_id'))
# ---
## Root route and home page
@app.route('/')
def home():
    """Serve a simple API documentation home page rendered from ``index.html``."""
    landing_page = render_template('index.html')
    return landing_page
if __name__ == '__main__':
    # CORS is already enabled on ``app`` at module level, so it is not applied
    # a second time here.

    # Make sure the configured working directories exist before serving.
    # A missing config key still raises KeyError, as before.
    for config_key in ('RESOURCE_DIR', 'TEMP_DIR', 'VOICE_DATA_DIR', 'FACE2FACE_TEMP_DIR'):
        os.makedirs(os.path.expanduser(app.config[config_key]), exist_ok=True)

    # Run the Flask application.
    # NOTE(review): debug=True together with host='0.0.0.0' exposes the
    # interactive debugger on all interfaces; disable debug outside development.
    app.run(host='0.0.0.0', port=5001, debug=True)