308 lines
9.2 KiB
Python
308 lines
9.2 KiB
Python
from flask import Flask, request, jsonify, render_template, abort, send_from_directory, Response
|
||
import logging
|
||
import os
|
||
from moviepy import VideoFileClip
|
||
import uuid
|
||
from typing import Optional
|
||
from flask_cors import CORS
|
||
from typing import Dict, Any
|
||
import threading
|
||
# Import the functions and classes this app uses from the api module
|
||
from api import (
|
||
APIException,
|
||
# train_voice,
|
||
# generate_voice,
|
||
# generate_video,
|
||
get_video_generate_process,
|
||
generate_digital_human,
|
||
# download_video,
|
||
# download_audio,
|
||
download_generated_video,
|
||
|
||
# download_generated_audio,
|
||
list_available_files,
|
||
find_generated_files,
|
||
Config
|
||
)
|
||
|
||
# Import the file-upload module; fall back to a stub when it is not installed.
try:
    from file_upload import FileManager

    file_manager_class = FileManager()
except ImportError:

    class MockFileManager:
        """Minimal stand-in used when the ``file_upload`` module cannot be imported."""

        def copy_audio_for_tts(self, voice_file_name, temp_uuid):
            """Simulate the copy operation and return the would-be file name."""
            return f"tts_copy_{voice_file_name}"

        def save_file(self, file, file_type, custom_name=None):
            """Reject uploads with an explicit error: no real storage backend exists.

            Raises:
                RuntimeError: always, because the stub cannot store files.
            """
            raise RuntimeError(
                "file_upload module is not available; cannot save uploaded files"
            )

    # Bind the stub to the SAME name the try-branch defines; the upload routes
    # reference ``file_manager_class`` and previously hit a NameError here.
    file_manager_class = MockFileManager()
    # Keep the historical fallback name as an alias for backward compatibility.
    file_manager = file_manager_class
    logging.warning("Mock file_upload.file_manager is used. Actual file operations might fail.")
|
||
|
||
|
||
# In-memory task registry: maps a task id to that task's metadata dict.
TASKS: Dict[str, Dict[str, Any]] = {}


def create_task(task_type: str, params: Dict[str, Any]) -> str:
    """Register a new task in ``TASKS`` and return its id.

    The task starts with status ``"pending"``, progress ``0`` and no result.

    Args:
        task_type: Label stored under the ``"type"`` key.
        params: Parameters stored (by reference) under the ``"params"`` key.

    Returns:
        The new task id, a random UUID4 rendered as a string.
    """
    new_id = str(uuid.uuid4())
    record: Dict[str, Any] = {
        "status": "pending",
        "type": task_type,
        "params": params,
        "result": None,
        "progress": 0,
    }
    TASKS[new_id] = record
    return new_id
|
||
|
||
def update_task(
    task_id: str,
    status: str,
    progress: Optional[int] = None,
    result: Optional[Dict[str, Any]] = None,
) -> None:
    """Update the stored metadata of an existing task.

    Args:
        task_id: Id of the task to update. Unknown ids are ignored silently.
        status: New value for the ``"status"`` field (always written).
        progress: New progress value; left untouched when ``None``.
        result: New result payload; left untouched when ``None``.
    """
    task = TASKS.get(task_id)
    if task is None:
        return

    task["status"] = status
    if progress is not None:
        task["progress"] = progress
    if result is not None:
        task["result"] = result
|
||
|
||
def get_task(task_id):
    """Return the metadata dict stored for ``task_id``, or ``None`` if unknown."""
    # dict.get already yields None for a missing key.
    return TASKS.get(task_id)
|
||
|
||
def generate_digital_human(speech_text, sample_video, sample_voice, gen_uuid, delay_seconds=5):
    """Placeholder digital-human generator that only simulates heavy work.

    NOTE(review): this definition shadows the ``generate_digital_human``
    imported from ``api`` earlier in this file, so callers in this module run
    this stub rather than the imported implementation — confirm that is
    intended.

    Args:
        speech_text: Text to be spoken (unused by the stub).
        sample_video: Sample video reference (unused by the stub).
        sample_voice: Sample voice reference (unused by the stub).
        gen_uuid: Identifier echoed back in the result.
        delay_seconds: How long to sleep to simulate the work. Defaults to the
            previously hard-coded 5 seconds; values <= 0 skip the sleep.

    Returns:
        ``{"code": 200, "uuid": gen_uuid}``.
    """
    import time

    # time.sleep rejects negative values, so only sleep for a positive delay.
    if delay_seconds > 0:
        time.sleep(delay_seconds)  # simulate heavy work
    return {"code": 200, "uuid": gen_uuid}
|
||
|
||
|
||
def run_extract_audio(task_id, video_path, audio_uuid):
    """Extract the audio track of a video into a WAV file and record the outcome.

    Progress and outcome are reported through ``update_task``; any exception is
    caught and stored on the task as ``{"error": <message>}`` with status
    ``"failed"``.

    Args:
        task_id: Id of the task to update.
        video_path: Path of the source video file.
        audio_uuid: Identifier used as the output file name (``<uuid>.wav``).
    """
    try:
        update_task(task_id, "running", 10)
        if not os.path.exists(video_path):
            raise FileNotFoundError(f"Video file not found: {video_path}")

        # Join path components instead of concatenating strings, so the result
        # is correct whether or not RESOURCE_DIR ends with a separator.
        out_dir = os.path.join(Config.RESOURCE_DIR, "uploads", "audio")
        # The writer does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)
        audio_path = os.path.join(out_dir, f"{audio_uuid}.wav")

        clip = VideoFileClip(video_path)
        try:
            # A video without an audio track exposes ``audio`` as None.
            if clip.audio is None:
                raise ValueError(f"Video has no audio track: {video_path}")
            clip.audio.write_audiofile(audio_path, codec="pcm_s16le")
        finally:
            # Always release the reader, even when writing fails.
            clip.close()

        update_task(task_id, "completed", 100, {
            "uuid": audio_uuid,
            "download_url": f"/download/generated/audio/{audio_uuid}.wav"
        })
    except Exception as e:
        # Top-level boundary of the worker thread: log the traceback and
        # surface the message on the task instead of losing it.
        logging.getLogger(__name__).exception("Audio extraction failed for task %s", task_id)
        update_task(task_id, "failed", 100, {"error": str(e)})
|
||
|
||
# === Background runner ===
|
||
def run_generate_task(task_id, data, gen_uuid):
    """Run digital-human generation for a task and record the outcome.

    Marks the task ``"running"``, calls ``generate_digital_human`` with the
    ``speech_text``, ``sample_video`` and ``sample_voice`` entries of ``data``,
    then stores the download URL on success. Any exception marks the task
    ``"failed"`` with the error message as its result.

    Args:
        task_id: Id of the task to update.
        data: Request payload holding the three required fields.
        gen_uuid: Identifier of the generated output.
    """
    try:
        update_task(task_id, "running", 10)
        outcome = generate_digital_human(
            data["speech_text"], data["sample_video"], data["sample_voice"], gen_uuid
        )
        # NOTE(review): the URL's ``task_id`` query value is the result's
        # ``code`` field, not this function's ``task_id`` — confirm intended.
        payload = {
            "uuid": gen_uuid,
            "download_url": f"/download/generated/video/{gen_uuid}?task_id={outcome.get('code')}",
        }
        update_task(task_id, "completed", 100, payload)
    except Exception as err:
        update_task(task_id, "failed", 100, {"error": str(err)})
|
||
# --- Flask app setup ---
|
||
app = Flask(__name__)
# Load settings from the Config object of the api module.
app.config.from_object("api.Config")

# Root logging at INFO level, plus the module-level logger used by the routes.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Enable CORS on the application with flask_cors defaults.
CORS(app)
|
||
|
||
# --- Error handlers ---
|
||
@app.errorhandler(APIException)
def handle_api_exception(e):
    """Render an ``APIException`` as JSON using its own message and status code."""
    body = jsonify({"error": e.message})
    return body, e.status_code
|
||
|
||
@app.errorhandler(404)
def handle_not_found(e):
    """Return a JSON body for 404 errors."""
    body = jsonify({"error": "Resource not found"})
    return body, 404
|
||
|
||
@app.errorhandler(500)
def handle_server_error(e):
    """Return a generic JSON body for 500 errors (no exception details)."""
    body = jsonify({"error": "Internal server error"})
    return body, 500
|
||
|
||
@app.route("/api/video/extract_audio", methods=["POST"])
def api_extract_audio():
    """Submit an asynchronous audio-extraction task for a video file.

    Expects a JSON object with a ``video_path`` string. The extraction runs in
    a background thread; the response carries the task id and the URL to poll.

    Returns:
        JSON ``{"status": "submitted", "task_id": ..., "query_url": ...}``, or
        a 400 JSON error when the body is not a JSON object with a usable
        ``video_path``.
    """
    # silent=True yields None for a missing/malformed JSON body or a wrong
    # content type, so the client gets the JSON error below rather than a
    # framework-generated error page.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or "video_path" not in data:
        return jsonify({"error": "Missing field 'video_path'"}), 400

    video_path = data["video_path"]
    # Reject non-string values: os.path.exists would treat an integer as an
    # open file descriptor rather than a path.
    if not isinstance(video_path, str) or not video_path:
        return jsonify({"error": "Field 'video_path' must be a non-empty string"}), 400

    # NOTE(review): video_path comes straight from the client and is opened on
    # the server as-is — consider restricting it to an allowed directory.
    audio_uuid = str(uuid.uuid4())

    # Create async task
    task_id = create_task("extract_audio", {"video_path": video_path})
    thread = threading.Thread(target=run_extract_audio, args=(task_id, video_path, audio_uuid))
    thread.start()

    return jsonify({
        "status": "submitted",
        "task_id": task_id,
        "query_url": f"/api/task/status/{task_id}"
    })
|
||
|
||
# --- Example: Digital Human ---
|
||
@app.route("/api/digital_human/generate", methods=["POST"])
def api_generate_digital_human():
    """Submit an asynchronous digital-human generation task.

    Expects a JSON object with ``speech_text``, ``sample_video`` and
    ``sample_voice``; an optional ``uuid`` names the generated output. The
    work runs in a background thread.

    Returns:
        JSON ``{"status": "submitted", "task_id": ..., "query_url": ...}``, or
        a 400 JSON error when the body is not a JSON object containing every
        required field.
    """
    # silent=True yields None for a missing/malformed JSON body or a wrong
    # content type, so the client gets the JSON error below rather than a
    # framework-generated error page.
    data = request.get_json(silent=True)
    required_fields = ["speech_text", "sample_video", "sample_voice"]
    if not isinstance(data, dict) or not all(f in data for f in required_fields):
        return jsonify({
            "error": f"Missing one of the required fields: {', '.join(required_fields)}"
        }), 400

    # Create task
    task_id = create_task("digital_human", data)
    # Fall back to a fresh UUID when "uuid" is absent, null or empty.
    gen_uuid = data.get("uuid") or str(uuid.uuid4())

    # Run async in background thread
    thread = threading.Thread(target=run_generate_task, args=(task_id, data, gen_uuid))
    thread.start()

    return jsonify({
        "status": "submitted",
        "task_id": task_id,
        "query_url": f"/api/task/status/{task_id}"
    })
|
||
# ---
|
||
## File operation APIs
|
||
|
||
@app.route('/api/files/list', methods=['GET'])
def api_list_files():
    """
    List the available files.

    Optional query parameters:
    - directory: directory path (defaults to the configured RESOURCE_DIR)
    - file_type: file type ("video", "audio", "all"; defaults to "all")
    """
    # NOTE(review): ``directory`` is taken verbatim from the query string —
    # confirm list_available_files restricts which paths may be listed.
    query = request.args
    target_dir = query.get('directory', app.config['RESOURCE_DIR'])
    kind = query.get('file_type', 'all')
    return jsonify(list_available_files(target_dir, kind))
|
||
|
||
@app.route('/api/files/find/<uuid>', methods=['GET'])
def api_find_generated_files(uuid: str):
    """
    Find the generated files for the given UUID.

    Responds with 404 and a message when ``find_generated_files`` returns
    nothing, otherwise with the list of files.
    """
    matches = find_generated_files(uuid)
    if matches:
        return jsonify({"status": "success", "files": matches})
    return jsonify({"message": "No generated files found for this UUID"}), 404
|
||
|
||
|
||
UPLOAD_FOLDER = "/mnt/docker/resource"

# Create the upload folder at import time if it is missing
# (exist_ok=True makes this a no-op when it already exists).
os.makedirs(UPLOAD_FOLDER, exist_ok=True)

# Expose the upload folder through the Flask app configuration.
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
||
|
||
# Video file extensions accepted for upload (compared case-insensitively).
ALLOWED_EXTENSIONS = {'mp4', 'avi', 'mov', 'mkv', 'flv', 'webm'}


# --- Helper Function for File Type Validation ---
def allowed_file(filename):
    """
    Tell whether ``filename`` has an extension listed in ALLOWED_EXTENSIONS.

    The extension is the text after the last dot; a name with no dot is
    rejected.
    """
    _stem, dot, extension = filename.rpartition('.')
    # rpartition yields an empty separator when the name contains no dot.
    return dot == '.' and extension.lower() in ALLOWED_EXTENSIONS
|
||
|
||
# --- Video Upload API Endpoint ---
|
||
@app.route('/api/upload/video', methods=['POST'])
def upload_video():
    """Upload a video file.

    Reads the multipart field ``file`` and the optional form field
    ``custom_name``, then hands both to ``file_manager_class.save_file`` with
    type ``'video'``. Responds 400 when no ``file`` part is present and 500
    (with the error text) when anything raises.
    """
    try:
        uploads = request.files
        if 'file' not in uploads:
            return jsonify({"error": "没有文件"}), 400

        saved = file_manager_class.save_file(
            uploads['file'], 'video', request.form.get('custom_name')
        )

        response_body = {
            "success": True,
            "message": "视频上传成功",
            "file_info": saved,
        }
        return jsonify(response_body)

    except Exception as err:
        logger.error(f"Video upload failed: {str(err)}")
        return jsonify({"error": f"上传失败: {str(err)}"}), 500
|
||
|
||
|
||
@app.route('/api/upload/audio', methods=['POST'])
def upload_audio():
    """Upload an audio file.

    Reads the multipart field ``file`` and the optional form field
    ``custom_name``, then hands both to ``file_manager_class.save_file`` with
    type ``'audio'``. Responds 400 when no ``file`` part is present and 500
    (with the error text) when anything raises.
    """
    try:
        if 'file' not in request.files:
            return jsonify({"error": "没有文件"}), 400

        file = request.files['file']
        custom_name = request.form.get('custom_name')

        file_info = file_manager_class.save_file(file, 'audio', custom_name)

        return jsonify({
            "success": True,
            "message": "音频上传成功",
            "file_info": file_info
        })

    except Exception as e:
        # The log line previously said "Video upload failed" (copied from the
        # video route), which misattributed audio failures in the logs.
        logger.error(f"Audio upload failed: {str(e)}")
        return jsonify({"error": f"上传失败: {str(e)}"}), 500
|
||
|
||
@app.route("/api/task/status/<task_id>", methods=["GET"])
def api_task_status(task_id):
    """Return the stored metadata of a task as JSON, or a 404 JSON error."""
    record = get_task(task_id)
    if record:
        return jsonify(record)
    return jsonify({"error": "Task not found"}), 404
|
||
|
||
|
||
@app.route('/download/generated/video/<uuid>', methods=['GET'])
def download_generated_video_route(uuid: str):
    """
    Download the generated video file for the given UUID.

    Optional query parameters:
    - task_id: task id, forwarded unchanged to download_generated_video
      (None when absent)
    """
    return download_generated_video(uuid, request.args.get('task_id'))
|
||
|
||
|
||
# ---
|
||
## Root route and home page
|
||
|
||
@app.route('/')
def home():
    """Serve the home page, a simple API documentation page."""
    template_name = 'index.html'
    return render_template(template_name)
|
||
|
||
|
||
if __name__ == '__main__':
    # CORS is already enabled where the app is created, so it is not applied
    # a second time here.

    # Make sure the required directories exist. The previous code only read
    # these settings into a throwaway variable and never created anything.
    required_dirs = (
        app.config['RESOURCE_DIR'],
        app.config['TEMP_DIR'],
        os.path.expanduser(app.config['VOICE_DATA_DIR']),
        os.path.expanduser(app.config['FACE2FACE_TEMP_DIR']),
    )
    for directory in required_dirs:
        os.makedirs(directory, exist_ok=True)

    # Run the Flask application.
    # NOTE(review): debug=True together with host 0.0.0.0 exposes the
    # interactive debugger on every interface — disable debug outside local
    # development.
    app.run(host='0.0.0.0', port=5001, debug=True)