280 lines
10 KiB
Python
280 lines
10 KiB
Python
"""
|
||
文件上传和管理模块
|
||
支持多种文件类型的上传、验证和存储
|
||
"""
|
||
|
||
import os
|
||
import uuid
|
||
from werkzeug.utils import secure_filename
|
||
from flask import request, jsonify
|
||
import mimetypes
|
||
from typing import Dict, List, Optional, Tuple
|
||
import logging
|
||
from pathlib import Path
|
||
|
||
logger = logging.getLogger(__name__)


class FileManager:
    """Validate, store, list, and hand off uploaded media files.

    Uploads are kept under ``<upload_dir>/<file_type>/``. Selected files can
    be copied into the directories listed in ``docker_dirs`` (TTS and
    Face2Face inputs) or into ``RESOURCE_DIR`` (generated results).

    Every caller-supplied file name is joined through ``_safe_join`` so that
    ``..`` segments or absolute paths cannot reach outside the intended
    directory.
    """

    # Allowed (lower-case) file extensions per media category.
    ALLOWED_EXTENSIONS = {
        'video': {'.mp4', '.avi', '.mov', '.mkv', '.wmv', '.flv', '.webm'},
        'audio': {'.wav', '.mp3', '.aac', '.flac', '.ogg', '.m4a'},
        'image': {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.webp'}
    }

    # Upload size limit per media category, in megabytes.
    MAX_FILE_SIZES = {
        'video': 500,  # 500MB
        'audio': 50,   # 50MB
        'image': 10    # 10MB
    }

    # Sub-directories created below ``upload_dir``.
    UPLOAD_SUBDIRS = ('video', 'audio', 'image', 'temp')

    # Directory that receives generated files; override on a subclass or
    # instance to relocate it.
    RESOURCE_DIR = "/mnt/docker/resource"

    def __init__(self, upload_dir: str = "/mnt/docker/resource/uploads",
                 docker_dirs: Optional[Dict[str, str]] = None):
        """Create the manager and make sure its directories exist.

        Args:
            upload_dir: Root directory for uploaded files.
            docker_dirs: Optional overrides for the hand-off directories,
                keyed by ``'voice_data'`` and/or ``'face2face_data'``.
                Keys that are not given keep their default location under
                ``~/heygem_data``.
        """
        self.upload_dir = upload_dir
        # Default hand-off directories (comment in the original: Docker
        # container mount directories).
        self.docker_dirs = {
            'voice_data': os.path.expanduser("~/heygem_data/voice/data"),
            'face2face_data': os.path.expanduser("~/heygem_data/face2face/temp"),
        }
        if docker_dirs:
            self.docker_dirs.update(docker_dirs)
        self._ensure_upload_directories()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _safe_join(base_dir: str, *parts: str) -> str:
        """Join *parts* onto *base_dir*, refusing results outside *base_dir*.

        The check is lexical (``os.path.abspath``), so it rejects ``..``
        segments and absolute components without resolving symlinks.

        Returns:
            The joined path, exactly as ``os.path.join`` produces it.

        Raises:
            ValueError: If the joined path is *base_dir* itself or lies
                outside it.
        """
        candidate = os.path.join(base_dir, *parts)
        base_abs = os.path.abspath(base_dir)
        target_abs = os.path.abspath(candidate)
        try:
            inside = os.path.commonpath([base_abs, target_abs]) == base_abs
        except ValueError:
            # commonpath() raises when the paths cannot be compared
            # (for example different drives on Windows).
            inside = False
        if not inside or target_abs == base_abs:
            raise ValueError(f"Unsafe path component: {os.path.join(*parts)!r}")
        return candidate

    @staticmethod
    def _extension_of(raw_name: str) -> str:
        """Return the extension (with leading dot, original case) of *raw_name*.

        The extension is read from the raw client-supplied name because
        ``secure_filename`` strips non-ASCII characters: a name such as
        ``"视频.mp4"`` collapses to ``"mp4"`` and would lose its extension.
        Surrounding whitespace, dots and underscores are trimmed first.
        """
        trimmed = raw_name.strip().strip("._")
        return os.path.splitext(trimmed)[1]

    @staticmethod
    def _copy_file(source: str, destination: str) -> None:
        """Copy *source* to *destination* with ``shutil.copy2``."""
        import shutil  # function-local import, as in the original methods

        shutil.copy2(source, destination)

    @staticmethod
    def _build_file_info(file_type: str, filename: str, file_path: str) -> dict:
        """Assemble the metadata dictionary describing a stored file."""
        mime_type, _ = mimetypes.guess_type(file_path)
        return {
            "filename": filename,
            "file_path": file_path,
            "file_type": file_type,
            "file_size": os.path.getsize(file_path),
            "mime_type": mime_type,
            "relative_path": f"uploads/{file_type}/{filename}",
            "download_url": f"/download/upload/{file_type}/{filename}"
        }

    def _ensure_upload_directories(self):
        """Create the upload sub-directories and the hand-off directories."""
        for file_type in self.UPLOAD_SUBDIRS:
            os.makedirs(os.path.join(self.upload_dir, file_type), exist_ok=True)

        for dir_path in self.docker_dirs.values():
            os.makedirs(dir_path, exist_ok=True)
            logger.info("Ensured Docker mount directory exists: %s", dir_path)

    # ------------------------------------------------------------------
    # Hand-off to other services
    # ------------------------------------------------------------------

    def copy_audio_for_tts(self, filename: str, uuid: str) -> str:
        """Copy an uploaded audio file into the ``voice_data`` directory.

        Args:
            filename: Name of the file below ``<upload_dir>/audio``.
            uuid: Identifier used as the destination base name. (The
                parameter shadows the ``uuid`` module inside this method;
                the name is kept for caller compatibility.)

        Returns:
            The destination file name, ``"<uuid>.wav"``.

        Raises:
            ValueError: If *filename* or *uuid* would escape its directory.
            FileNotFoundError: If the source file does not exist.
            OSError: If the copy itself fails (logged, then re-raised).
        """
        source_path = self._safe_join(
            os.path.join(self.upload_dir, 'audio'), filename
        )
        if not os.path.exists(source_path):
            raise FileNotFoundError(f"Source audio file not found: {source_path}")

        # NOTE(review): the copy is always named ".wav" whatever the source
        # format is; no conversion happens here. Confirm the consumer
        # accepts that.
        tts_filename = f"{uuid}.wav"
        tts_dest = self._safe_join(self.docker_dirs['voice_data'], tts_filename)

        try:
            self._copy_file(source_path, tts_dest)
        except OSError as e:
            logger.error("Failed to copy audio for TTS: %s", e)
            raise
        logger.info("Copied audio for TTS: %s -> %s", source_path, tts_dest)
        return tts_filename

    def copy_files_for_face2face(self, video_filename: str, audio_filename: str, uuid: str) -> Tuple[str, str]:
        """Copy an uploaded video/audio pair into the ``face2face_data`` directory.

        Both copies are named after *uuid* and keep their original
        extensions. If the audio copy fails, the video copy made just before
        is removed again so no half-finished pair is left behind.

        Args:
            video_filename: Name of the file below ``<upload_dir>/video``.
            audio_filename: Name of the file below ``<upload_dir>/audio``.
            uuid: Identifier used as the destination base name.

        Returns:
            ``(video_name, audio_name)`` of the copies.

        Raises:
            ValueError: If any name would escape its directory.
            FileNotFoundError: If either source file does not exist.
            OSError: If a copy fails (logged, then re-raised).
        """
        target_dir = self.docker_dirs['face2face_data']

        video_source = self._safe_join(
            os.path.join(self.upload_dir, 'video'), video_filename
        )
        if not os.path.exists(video_source):
            raise FileNotFoundError(f"Source video file not found: {video_source}")
        face2face_video = f"{uuid}{os.path.splitext(video_filename)[1]}"
        video_dest = self._safe_join(target_dir, face2face_video)

        audio_source = self._safe_join(
            os.path.join(self.upload_dir, 'audio'), audio_filename
        )
        if not os.path.exists(audio_source):
            raise FileNotFoundError(f"Source audio file not found: {audio_source}")
        face2face_audio = f"{uuid}{os.path.splitext(audio_filename)[1]}"
        audio_dest = self._safe_join(target_dir, face2face_audio)

        try:
            self._copy_file(video_source, video_dest)
            try:
                self._copy_file(audio_source, audio_dest)
            except OSError:
                # Roll back the video so the pair is all-or-nothing.
                self.delete_file(video_dest)
                raise
        except OSError as e:
            logger.error("Failed to copy files for Face2Face: %s", e)
            raise
        logger.info(
            "Copied files for Face2Face: video=%s, audio=%s",
            face2face_video, face2face_audio,
        )
        return face2face_video, face2face_audio

    def copy_generated_file_to_resource(self, source_path: str, filename: str, file_type: str = 'output') -> str:
        """Copy a generated file into ``RESOURCE_DIR``.

        Args:
            source_path: Path of the generated file.
            filename: Destination name inside ``RESOURCE_DIR``.
            file_type: Currently unused; kept for caller compatibility.

        Returns:
            The destination path.

        Raises:
            ValueError: If *filename* would escape ``RESOURCE_DIR``.
            OSError: If the copy fails (logged, then re-raised).
        """
        resource_dir = self.RESOURCE_DIR
        os.makedirs(resource_dir, exist_ok=True)
        dest_path = self._safe_join(resource_dir, filename)

        try:
            self._copy_file(source_path, dest_path)
        except OSError as e:
            logger.error("Failed to copy generated file: %s", e)
            raise
        logger.info("Copied generated file to resource: %s -> %s", source_path, dest_path)
        return dest_path

    # ------------------------------------------------------------------
    # Upload handling
    # ------------------------------------------------------------------

    def validate_file(self, file, file_type: str) -> Tuple[bool, str]:
        """Check an uploaded file's name, extension and size.

        Args:
            file: Upload object exposing ``filename``, ``seek`` and ``tell``.
            file_type: Category key of ``ALLOWED_EXTENSIONS``.

        Returns:
            ``(ok, message)``; *message* is the user-facing text.
        """
        if not file or not file.filename:
            return False, "没有选择文件"

        file_ext = self._extension_of(file.filename).lower()
        if file_ext not in self.ALLOWED_EXTENSIONS.get(file_type, set()):
            return False, f"不支持的{file_type}文件格式: {file_ext}"

        # Measure the size by seeking to the end, then rewind so the stream
        # can still be saved afterwards.
        file.seek(0, os.SEEK_END)
        file_size = file.tell()
        file.seek(0)

        limit_mb = self.MAX_FILE_SIZES.get(file_type, 10)
        if file_size > limit_mb * 1024 * 1024:
            return False, f"文件大小超出限制 ({limit_mb}MB)"

        return True, "文件验证通过"

    def save_file(self, file, file_type: str, custom_filename: Optional[str] = None) -> Dict[str, str]:
        """Validate an upload and store it below ``<upload_dir>/<file_type>``.

        Args:
            file: Upload object (``filename``, ``seek``, ``tell``, ``save``).
            file_type: Category key of ``ALLOWED_EXTENSIONS``.
            custom_filename: Optional name to store the file under; it is
                passed through ``secure_filename``. Without it, a random
                hex name with the upload's original extension is generated.

        Returns:
            Metadata for the stored file. Note that ``file_size`` is an
            ``int`` and ``mime_type`` may be ``None`` despite the declared
            value type.

        Raises:
            ValueError: If validation fails or *custom_filename* sanitises
                to an empty name.
        """
        is_valid, message = self.validate_file(file, file_type)
        if not is_valid:
            raise ValueError(message)

        save_dir = os.path.join(self.upload_dir, file_type)
        os.makedirs(save_dir, exist_ok=True)

        if custom_filename:
            filename = secure_filename(custom_filename)
            if not filename:
                # secure_filename() may strip every character; an empty
                # name would make the target path the directory itself.
                raise ValueError(f"自定义文件名无效: {custom_filename}")
        else:
            # validate_file() has already confirmed this extension is allowed.
            filename = f"{uuid.uuid4().hex}{self._extension_of(file.filename)}"

        file_path = os.path.join(save_dir, filename)
        logger.debug("Saving upload to %s", file_path)
        file.save(file_path)

        info = self._build_file_info(file_type, filename, file_path)
        logger.info(
            "Saved %s file to backup storage: %s (%s bytes)",
            file_type, filename, info["file_size"],
        )
        return info

    def delete_file(self, file_path: str) -> bool:
        """Delete *file_path*; return ``True`` only if a file was removed.

        Failures are logged and reported as ``False`` rather than raised.
        """
        try:
            os.remove(file_path)
        except FileNotFoundError:
            return False
        except (OSError, TypeError, ValueError) as e:
            logger.error("Failed to delete file %s: %s", file_path, e)
            return False
        logger.info("Deleted file: %s", file_path)
        return True

    def get_file_info(self, file_type: str, filename: str) -> Optional[Dict[str, str]]:
        """Return metadata for a stored upload.

        Returns ``None`` when the path is not a regular file or would lie
        outside ``upload_dir``.
        """
        try:
            file_path = self._safe_join(self.upload_dir, file_type, filename)
        except ValueError:
            logger.warning(
                "Rejected file lookup outside upload directory: %r/%r",
                file_type, filename,
            )
            return None
        if not os.path.isfile(file_path):
            return None
        return self._build_file_info(file_type, filename, file_path)

    def list_files(self, file_type: Optional[str] = None) -> List[Dict[str, str]]:
        """List metadata of stored uploads.

        Args:
            file_type: Restrict the listing to one category; when omitted,
                ``video``, ``audio`` and ``image`` are listed.

        Returns:
            One metadata dictionary per regular file, sorted by name within
            each category.
        """
        file_types = [file_type] if file_type else ['video', 'audio', 'image']

        files = []
        for ft in file_types:
            try:
                type_dir = self._safe_join(self.upload_dir, ft)
            except ValueError:
                continue
            if not os.path.isdir(type_dir):
                continue
            for filename in sorted(os.listdir(type_dir)):
                file_info = self.get_file_info(ft, filename)
                if file_info:
                    files.append(file_info)
        return files

    def cleanup_temp_files(self, older_than_hours: int = 24):
        """Delete files in ``<upload_dir>/temp`` older than *older_than_hours*.

        Age is judged by modification time. Only files that were actually
        removed are logged as cleaned up.
        """
        import time

        temp_dir = os.path.join(self.upload_dir, 'temp')
        if not os.path.isdir(temp_dir):
            return

        cutoff_time = time.time() - older_than_hours * 3600

        for filename in os.listdir(temp_dir):
            file_path = os.path.join(temp_dir, filename)
            if not os.path.isfile(file_path):
                continue
            try:
                is_stale = os.path.getmtime(file_path) < cutoff_time
            except OSError:
                # The file disappeared between listing and inspection.
                continue
            if is_stale and self.delete_file(file_path):
                logger.info("Cleaned up temp file: %s", filename)
|
||
|
||
# Module-wide FileManager instance shared by the helper functions below.
# Constructed at import time with the default upload directory.
file_manager: FileManager = FileManager()
|
||
|
||
def save_uploaded_file(file, file_type: str, custom_filename: Optional[str] = None) -> Dict[str, str]:
    """Validate and store an uploaded file via the shared ``file_manager``.

    Args:
        file: The upload object handed to ``file_manager.save_file``.
        file_type: Media category of the upload.
        custom_filename: Optional name to store the file under; ``None``
            lets the manager choose one.

    Returns:
        Whatever ``file_manager.save_file`` returns for the stored file.
    """
    return file_manager.save_file(file, file_type, custom_filename)
|
||
|
||
def get_uploaded_file_info(file_type: str, filename: str) -> Optional[Dict[str, str]]:
    """Look up an uploaded file's metadata via the shared ``file_manager``.

    Returns the result of ``file_manager.get_file_info`` unchanged.
    """
    info = file_manager.get_file_info(file_type, filename)
    return info
|
||
|
||
def list_uploaded_files(file_type: Optional[str] = None) -> List[Dict[str, str]]:
    """List uploaded files via the shared ``file_manager``.

    Args:
        file_type: Media category to list; ``None`` is passed through to
            ``file_manager.list_files`` unchanged.

    Returns:
        Whatever ``file_manager.list_files`` returns.
    """
    return file_manager.list_files(file_type)
|
||
|
||
def validate_uploaded_file(file, file_type: str) -> Tuple[bool, str]:
    """Validate an uploaded file via the shared ``file_manager``.

    Returns the ``(is_valid, message)`` pair produced by
    ``file_manager.validate_file``.
    """
    verdict = file_manager.validate_file(file, file_type)
    return verdict
|