initial
This commit is contained in:
308
flask_api.py
Normal file
308
flask_api.py
Normal file
@@ -0,0 +1,308 @@
|
||||
from flask import Flask, request, jsonify, render_template, abort, send_from_directory, Response
|
||||
import logging
|
||||
import os
|
||||
from moviepy import VideoFileClip
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from flask_cors import CORS
|
||||
from typing import Dict, Any
|
||||
import threading
|
||||
# 导入API模块中的所有函数和类
|
||||
from api import (
|
||||
APIException,
|
||||
# train_voice,
|
||||
# generate_voice,
|
||||
# generate_video,
|
||||
get_video_generate_process,
|
||||
generate_digital_human,
|
||||
# download_video,
|
||||
# download_audio,
|
||||
download_generated_video,
|
||||
|
||||
# download_generated_audio,
|
||||
list_available_files,
|
||||
find_generated_files,
|
||||
Config
|
||||
)
|
||||
|
||||
# 导入文件上传模块 (假设文件存在)
|
||||
try:
|
||||
from file_upload import FileManager
|
||||
file_manager_class=FileManager()
|
||||
|
||||
|
||||
except ImportError:
|
||||
# 如果file_upload模块不存在,创建一个简单的模拟版本
|
||||
class MockFileManager:
|
||||
def copy_audio_for_tts(self, voice_file_name, temp_uuid):
|
||||
# 模拟文件复制操作,返回一个文件名
|
||||
return f"tts_copy_{voice_file_name}"
|
||||
file_manager = MockFileManager()
|
||||
logging.warning("Mock file_upload.file_manager is used. Actual file operations might fail.")
|
||||
|
||||
|
||||
# 错误处理器
|
||||
TASKS: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def create_task(task_type: str, params: Dict[str, Any]) -> str:
|
||||
"""Create a new task and store metadata"""
|
||||
task_id = str(uuid.uuid4())
|
||||
TASKS[task_id] = {
|
||||
"status": "pending",
|
||||
"type": task_type,
|
||||
"params": params,
|
||||
"result": None,
|
||||
"progress": 0
|
||||
}
|
||||
return task_id
|
||||
|
||||
def update_task(task_id: str, status: str, progress: int = None, result: Dict[str, Any] = None):
|
||||
"""Update task metadata"""
|
||||
if task_id in TASKS:
|
||||
TASKS[task_id]["status"] = status
|
||||
if progress is not None:
|
||||
TASKS[task_id]["progress"] = progress
|
||||
if result is not None:
|
||||
TASKS[task_id]["result"] = result
|
||||
|
||||
def get_task(task_id):
|
||||
return TASKS.get(task_id, None)
|
||||
|
||||
def generate_digital_human(speech_text, sample_video, sample_voice, gen_uuid):
|
||||
import time
|
||||
time.sleep(5) # simulate heavy work
|
||||
return {"code": 200, "uuid": gen_uuid}
|
||||
|
||||
|
||||
def run_extract_audio(task_id, video_path, audio_uuid):
|
||||
try:
|
||||
update_task(task_id, "running", 10)
|
||||
if not os.path.exists(video_path):
|
||||
raise FileNotFoundError(f"Video file not found: {video_path}")
|
||||
|
||||
clip = VideoFileClip(video_path)
|
||||
|
||||
# Save extracted audio into temp dir
|
||||
out_dir = Config.RESOURCE_DIR+'uploads/audio'
|
||||
audio_path = os.path.join(out_dir, f"{audio_uuid}.wav")
|
||||
clip.audio.write_audiofile(audio_path, codec="pcm_s16le")
|
||||
clip.close()
|
||||
|
||||
update_task(task_id, "completed", 100, {
|
||||
"uuid": audio_uuid,
|
||||
"download_url": f"/download/generated/audio/{audio_uuid}.wav"
|
||||
})
|
||||
except Exception as e:
|
||||
update_task(task_id, "failed", 100, {"error": str(e)})
|
||||
|
||||
# === Background runner ===
|
||||
def run_generate_task(task_id, data, gen_uuid):
|
||||
try:
|
||||
update_task(task_id, "running", 10)
|
||||
result = generate_digital_human(
|
||||
data["speech_text"],
|
||||
data["sample_video"],
|
||||
data["sample_voice"],
|
||||
gen_uuid
|
||||
)
|
||||
update_task(task_id, "completed", 100, {
|
||||
"uuid": gen_uuid,
|
||||
"download_url": f"/download/generated/video/{gen_uuid}?task_id={result.get('code')}"
|
||||
})
|
||||
except Exception as e:
|
||||
update_task(task_id, "failed", 100, {"error": str(e)})
|
||||
# --- Flask app setup ---
|
||||
app = Flask(__name__)
|
||||
app.config.from_object('api.Config')
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
CORS(app)
|
||||
|
||||
# --- Error handlers ---
|
||||
@app.errorhandler(APIException)
|
||||
def handle_api_exception(e):
|
||||
return jsonify({"error": e.message}), e.status_code
|
||||
|
||||
@app.errorhandler(404)
|
||||
def handle_not_found(e):
|
||||
return jsonify({"error": "Resource not found"}), 404
|
||||
|
||||
@app.errorhandler(500)
|
||||
def handle_server_error(e):
|
||||
return jsonify({"error": "Internal server error"}), 500
|
||||
|
||||
@app.route("/api/video/extract_audio", methods=["POST"])
|
||||
def api_extract_audio():
|
||||
data = request.json
|
||||
if not data or "video_path" not in data:
|
||||
return jsonify({"error": "Missing field 'video_path'"}), 400
|
||||
|
||||
video_path = data["video_path"]
|
||||
audio_uuid = str(uuid.uuid4())
|
||||
|
||||
# Create async task
|
||||
task_id = create_task("extract_audio", {"video_path": video_path})
|
||||
thread = threading.Thread(target=run_extract_audio, args=(task_id, video_path, audio_uuid))
|
||||
thread.start()
|
||||
|
||||
return jsonify({
|
||||
"status": "submitted",
|
||||
"task_id": task_id,
|
||||
"query_url": f"/api/task/status/{task_id}"
|
||||
})
|
||||
|
||||
# --- Example: Digital Human ---
|
||||
@app.route("/api/digital_human/generate", methods=["POST"])
|
||||
def api_generate_digital_human():
|
||||
data = request.json
|
||||
required_fields = ["speech_text", "sample_video", "sample_voice"]
|
||||
if not data or not all(f in data for f in required_fields):
|
||||
return jsonify({
|
||||
"error": f"Missing one of the required fields: {', '.join(required_fields)}"
|
||||
}), 400
|
||||
|
||||
# Create task
|
||||
task_id = create_task("digital_human", data)
|
||||
gen_uuid = data.get("uuid", str(uuid.uuid4()))
|
||||
|
||||
# Run async in background thread
|
||||
thread = threading.Thread(target=run_generate_task, args=(task_id, data, gen_uuid))
|
||||
thread.start()
|
||||
|
||||
return jsonify({
|
||||
"status": "submitted",
|
||||
"task_id": task_id,
|
||||
"query_url": f"/api/task/status/{task_id}"
|
||||
})
|
||||
# ---
|
||||
## 文件操作API
|
||||
|
||||
@app.route('/api/files/list', methods=['GET'])
|
||||
def api_list_files():
|
||||
"""
|
||||
列出可用的文件
|
||||
可选参数:
|
||||
- directory: 目录路径
|
||||
- file_type: 文件类型 ("video", "audio", "all")
|
||||
"""
|
||||
directory = request.args.get('directory', app.config['RESOURCE_DIR'])
|
||||
file_type = request.args.get('file_type', 'all')
|
||||
|
||||
result = list_available_files(directory, file_type)
|
||||
return jsonify(result)
|
||||
|
||||
@app.route('/api/files/find/<uuid>', methods=['GET'])
|
||||
def api_find_generated_files(uuid: str):
|
||||
"""
|
||||
查找指定UUID的生成文件
|
||||
"""
|
||||
files = find_generated_files(uuid)
|
||||
if not files:
|
||||
return jsonify({"message": "No generated files found for this UUID"}), 404
|
||||
return jsonify({"status": "success", "files": files})
|
||||
|
||||
|
||||
UPLOAD_FOLDER = "/mnt/docker/resource"
|
||||
# Ensure the upload folder exists; create it if it doesn't
|
||||
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
|
||||
|
||||
# Set the upload folder for the Flask app
|
||||
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
|
||||
|
||||
# Define allowed video extensions for security
|
||||
ALLOWED_EXTENSIONS = {'mp4', 'avi', 'mov', 'mkv', 'flv', 'webm'}
|
||||
|
||||
# --- Helper Function for File Type Validation ---
|
||||
def allowed_file(filename):
|
||||
"""
|
||||
Checks if the uploaded file's extension is in the ALLOWED_EXTENSIONS set.
|
||||
"""
|
||||
return '.' in filename and \
|
||||
filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||
|
||||
# --- Video Upload API Endpoint ---
|
||||
@app.route('/api/upload/video', methods=['POST'])
|
||||
def upload_video():
|
||||
|
||||
"""上传视频文件"""
|
||||
try:
|
||||
if 'file' not in request.files:
|
||||
return jsonify({"error": "没有文件"}), 400
|
||||
|
||||
file = request.files['file']
|
||||
custom_name = request.form.get('custom_name')
|
||||
|
||||
file_info = file_manager_class.save_file(file, 'video', custom_name)
|
||||
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"message": "视频上传成功",
|
||||
"file_info": file_info
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Video upload failed: {str(e)}")
|
||||
return jsonify({"error": f"上传失败: {str(e)}"}), 500
|
||||
|
||||
|
||||
@app.route('/api/upload/audio', methods=['POST'])
|
||||
def upload_audio():
|
||||
|
||||
"""上传音频文件"""
|
||||
try:
|
||||
if 'file' not in request.files:
|
||||
return jsonify({"error": "没有文件"}), 400
|
||||
|
||||
file = request.files['file']
|
||||
custom_name = request.form.get('custom_name')
|
||||
|
||||
file_info = file_manager_class.save_file(file, 'audio', custom_name)
|
||||
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"message": "音频上传成功",
|
||||
"file_info": file_info
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Video upload failed: {str(e)}")
|
||||
return jsonify({"error": f"上传失败: {str(e)}"}), 500
|
||||
|
||||
@app.route("/api/task/status/<task_id>", methods=["GET"])
|
||||
def api_task_status(task_id):
|
||||
task = get_task(task_id)
|
||||
if not task:
|
||||
return jsonify({"error": "Task not found"}), 404
|
||||
return jsonify(task)
|
||||
|
||||
|
||||
@app.route('/download/generated/video/<uuid>', methods=['GET'])
|
||||
def download_generated_video_route(uuid: str):
|
||||
"""
|
||||
下载指定UUID的生成的视频文件
|
||||
可选参数:
|
||||
- task_id: 任务ID
|
||||
"""
|
||||
task_id = request.args.get('task_id')
|
||||
return download_generated_video(uuid, task_id)
|
||||
|
||||
|
||||
# ---
|
||||
## 根路由和首页
|
||||
|
||||
@app.route('/')
|
||||
def home():
|
||||
"""提供一个简单的API文档首页"""
|
||||
return render_template('index.html')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
CORS(app)
|
||||
# 确保必要的目录存在
|
||||
_ = app.config['RESOURCE_DIR']
|
||||
_ = app.config['TEMP_DIR']
|
||||
_ = os.path.expanduser(app.config['VOICE_DATA_DIR'])
|
||||
_ = os.path.expanduser(app.config['FACE2FACE_TEMP_DIR'])
|
||||
|
||||
# 运行Flask应用
|
||||
app.run(host='0.0.0.0', port=5001, debug=True)
|
||||
Reference in New Issue
Block a user