365 lines
16 KiB
Python
365 lines
16 KiB
Python
import logging
|
||
import uuid
|
||
from rest_framework import viewsets, status
|
||
from rest_framework.decorators import action, api_view, permission_classes, parser_classes
|
||
from rest_framework.response import Response
|
||
from rest_framework.parsers import MultiPartParser, FormParser, JSONParser
|
||
from rest_framework.permissions import AllowAny
|
||
from django.conf import settings
|
||
from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiTypes
|
||
from .models import TranscriptionTask, AIEvaluation
|
||
from .serializers import TranscriptionTaskSerializer, TranscriptionUploadSerializer, AIEvaluationSerializer
|
||
from .services import AliyunTingwuService
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
@api_view(['POST'])
|
||
@permission_classes([AllowAny])
|
||
def tingwu_callback(request):
|
||
"""
|
||
处理阿里云听悟的回调消息
|
||
"""
|
||
data = request.data
|
||
logger.info(f"收到听悟回调: {data}")
|
||
|
||
# 1. 处理连通性测试消息
|
||
# 格式: {"Code": "0", "Data": {"Test": "..."}, "Message": "success", "RequestId": "..."}
|
||
if isinstance(data, dict) and 'Data' in data and 'Test' in data['Data']:
|
||
logger.info("收到听悟连通性测试请求")
|
||
return Response({'message': 'success'}, status=status.HTTP_200_OK)
|
||
|
||
# 2. 处理任务完成消息 (根据实际文档或后续调试完善)
|
||
# 通常会包含 TaskId 和 Status
|
||
# 注意:阿里云听悟回调的结构可能在 Header 或 Body 中不同,需根据实际情况调整
|
||
# 这里是一个通用的处理逻辑
|
||
task_id = data.get('TaskId')
|
||
task_status = data.get('Status')
|
||
|
||
if task_id:
|
||
try:
|
||
task = TranscriptionTask.objects.filter(task_id=task_id).first()
|
||
if task:
|
||
if task_status == 'COMPLETE':
|
||
logger.info(f"任务 {task_id} 完成,等待下一次查询刷新")
|
||
# 可以在这里直接调用 get_task_info 刷新数据,但要注意超时
|
||
elif task_status == 'FAILED':
|
||
task.status = TranscriptionTask.Status.FAILED
|
||
task.error_message = data.get('StatusText', 'Callback reported failure')
|
||
task.save()
|
||
else:
|
||
logger.warning(f"回调收到未知任务ID: {task_id}")
|
||
except Exception as e:
|
||
logger.error(f"处理回调异常: {e}")
|
||
|
||
return Response({'message': 'success'}, status=status.HTTP_200_OK)
|
||
|
||
class TranscriptionTaskViewSet(viewsets.ModelViewSet):
|
||
queryset = TranscriptionTask.objects.all()
|
||
serializer_class = TranscriptionTaskSerializer
|
||
parser_classes = (MultiPartParser, FormParser)
|
||
|
||
@extend_schema(
|
||
request={
|
||
'multipart/form-data': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'file': {
|
||
'type': 'string',
|
||
'format': 'binary'
|
||
},
|
||
'file_url': {
|
||
'type': 'string',
|
||
'description': '音频文件的URL地址'
|
||
},
|
||
'project_id': {
|
||
'type': 'integer',
|
||
'description': '关联的参赛项目ID'
|
||
}
|
||
}
|
||
}
|
||
},
|
||
responses={201: TranscriptionTaskSerializer}
|
||
)
|
||
def create(self, request, *args, **kwargs):
|
||
"""
|
||
上传音频文件并创建听悟转写任务
|
||
"""
|
||
file_obj = request.FILES.get('file')
|
||
file_url = request.data.get('file_url')
|
||
project_id = request.data.get('project_id')
|
||
|
||
if not file_obj and not file_url:
|
||
return Response({'error': '请提供文件或文件URL'}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
service = AliyunTingwuService()
|
||
if not service.bucket or not service.client:
|
||
return Response({'error': '阿里云服务未配置'}, status=status.HTTP_503_SERVICE_UNAVAILABLE)
|
||
|
||
try:
|
||
oss_url = None
|
||
if file_obj:
|
||
# 1. 上传文件到 OSS
|
||
file_extension = file_obj.name.split('.')[-1]
|
||
file_name = f"transcription/{uuid.uuid4()}.{file_extension}"
|
||
|
||
# 使用服务上传
|
||
oss_url = service.upload_to_oss(file_obj, file_name)
|
||
else:
|
||
# 使用提供的 URL
|
||
oss_url = file_url
|
||
|
||
# 2. 创建数据库记录
|
||
task_data = {
|
||
'file_url': oss_url,
|
||
'status': TranscriptionTask.Status.PENDING
|
||
}
|
||
if project_id:
|
||
try:
|
||
p_id = int(project_id)
|
||
# 只有当 ID > 0 时才认为是有效的项目 ID
|
||
# 避免前端传递 0 或 Swagger 默认值导致的外键约束错误
|
||
if p_id > 0:
|
||
task_data['project_id'] = p_id
|
||
except (ValueError, TypeError):
|
||
pass # Ignore invalid project_id
|
||
|
||
task_record = TranscriptionTask.objects.create(**task_data)
|
||
logger.info(f"Created TranscriptionTask {task_record.id} with project_id={project_id}")
|
||
|
||
# 3. 调用听悟接口创建任务
|
||
try:
|
||
tingwu_response = service.create_transcription_task(oss_url)
|
||
|
||
# 兼容处理响应结构,通常为 {"Data": {"TaskId": "...", ...}}
|
||
if 'Data' in tingwu_response and isinstance(tingwu_response['Data'], dict):
|
||
task_id = tingwu_response['Data'].get('TaskId')
|
||
else:
|
||
task_id = tingwu_response.get('TaskId')
|
||
|
||
if task_id:
|
||
task_record.task_id = task_id
|
||
task_record.status = TranscriptionTask.Status.PROCESSING
|
||
task_record.save()
|
||
else:
|
||
task_record.status = TranscriptionTask.Status.FAILED
|
||
task_record.error_message = "未能获取 TaskId"
|
||
task_record.save()
|
||
return Response({'error': '未能获取 TaskId'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
except Exception as e:
|
||
task_record.status = TranscriptionTask.Status.FAILED
|
||
task_record.error_message = str(e)
|
||
task_record.save()
|
||
logger.error(f"创建听悟任务失败: {e}")
|
||
return Response({'error': f"创建听悟任务失败: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
serializer = self.get_serializer(task_record)
|
||
return Response(serializer.data, status=status.HTTP_201_CREATED)
|
||
|
||
except Exception as e:
|
||
logger.error(f"处理上传请求失败: {e}")
|
||
return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
@action(detail=True, methods=['post'])
|
||
@extend_schema(
|
||
request={
|
||
'application/json': {
|
||
'type': 'object',
|
||
'properties': {
|
||
'model_selection': {'type': 'string', 'description': '模型选择'},
|
||
'prompt': {'type': 'string', 'description': '评分提示词'},
|
||
}
|
||
}
|
||
},
|
||
responses={200: AIEvaluationSerializer(many=True)}
|
||
)
|
||
def evaluate(self, request, pk=None):
|
||
"""
|
||
触发AI评估
|
||
"""
|
||
task = self.get_object()
|
||
|
||
# 1. 如果有 active template,触发所有 active template
|
||
# 2. 如果请求体提供了 custom prompt,则创建一个 custom evaluation (no template)
|
||
|
||
from .models import AIEvaluationTemplate
|
||
from .bailian_service import BailianService
|
||
service = BailianService()
|
||
|
||
evaluations_to_process = []
|
||
|
||
# A. 如果指定了 Prompt/Model,视为手动单次评估
|
||
model_selection = request.data.get('model_selection')
|
||
prompt = request.data.get('prompt')
|
||
|
||
if prompt:
|
||
# 创建一个不关联 Template 的评估
|
||
eval, _ = AIEvaluation.objects.get_or_create(
|
||
task=task,
|
||
template=None,
|
||
defaults={
|
||
'model_selection': model_selection or 'qwen-plus',
|
||
'prompt': prompt
|
||
}
|
||
)
|
||
# 更新配置
|
||
eval.model_selection = model_selection or eval.model_selection
|
||
eval.prompt = prompt
|
||
eval.save()
|
||
evaluations_to_process.append(eval)
|
||
else:
|
||
# B. 否则触发所有 Active Templates
|
||
active_templates = AIEvaluationTemplate.objects.filter(is_active=True)
|
||
if not active_templates.exists():
|
||
return Response({'message': 'No active templates and no custom prompt provided'}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
for t in active_templates:
|
||
eval, _ = AIEvaluation.objects.get_or_create(
|
||
task=task,
|
||
template=t,
|
||
defaults={
|
||
'model_selection': t.model_selection,
|
||
'prompt': t.prompt
|
||
}
|
||
)
|
||
# 始终更新为模板最新配置? 或者保留历史? 用户意图似乎是"模版搭好...启用...生成几份"
|
||
# 这里假设触发时应用模板当前配置
|
||
eval.model_selection = t.model_selection
|
||
eval.prompt = t.prompt
|
||
eval.save()
|
||
evaluations_to_process.append(eval)
|
||
|
||
# 执行评估 (改为异步并发执行)
|
||
# 提取ID列表,避免传递模型对象导致可能的线程问题
|
||
eval_ids = [e.id for e in evaluations_to_process]
|
||
|
||
if eval_ids:
|
||
import threading
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
def run_evaluations_background(ids):
|
||
# 在后台线程中重新引入依赖
|
||
from .models import AIEvaluation
|
||
from .bailian_service import BailianService
|
||
|
||
# 为该线程创建独立的服务实例
|
||
local_service = BailianService()
|
||
|
||
# 获取最新的对象
|
||
target_evals = AIEvaluation.objects.filter(id__in=ids)
|
||
|
||
# 使用线程池并发执行
|
||
# max_workers=4 可以同时处理4个评估请求
|
||
with ThreadPoolExecutor(max_workers=4) as executor:
|
||
executor.map(local_service.evaluate_task, target_evals)
|
||
|
||
# 启动后台线程,不阻塞当前 HTTP 请求
|
||
thread = threading.Thread(target=run_evaluations_background, args=(eval_ids,))
|
||
thread.daemon = True # 设置为守护线程
|
||
thread.start()
|
||
|
||
# 返回该任务的所有评估结果
|
||
all_evals = AIEvaluation.objects.filter(task=task)
|
||
serializer = AIEvaluationSerializer(all_evals, many=True)
|
||
return Response(serializer.data)
|
||
|
||
@action(detail=True, methods=['get'])
|
||
@extend_schema(
|
||
parameters=[
|
||
OpenApiParameter("id", OpenApiTypes.UUID, OpenApiParameter.PATH, description="Task ID"),
|
||
],
|
||
responses={200: TranscriptionTaskSerializer}
|
||
)
|
||
def refresh_status(self, request, pk=None):
|
||
"""
|
||
刷新任务状态并获取结果
|
||
"""
|
||
task = self.get_object()
|
||
|
||
# 允许刷新的条件:
|
||
# 1. 任务未完成 (PENDING, PROCESSING)
|
||
# 2. 任务已完成但逐字稿 (transcription) 为空
|
||
# 3. 任务已完成但 AI总结 (summary) 为空 (新增)
|
||
|
||
should_refresh = False
|
||
if task.status not in [TranscriptionTask.Status.SUCCEEDED, TranscriptionTask.Status.FAILED]:
|
||
should_refresh = True
|
||
elif task.status == TranscriptionTask.Status.SUCCEEDED:
|
||
if not task.transcription or not task.summary:
|
||
should_refresh = True
|
||
|
||
if not should_refresh:
|
||
serializer = self.get_serializer(task)
|
||
return Response(serializer.data)
|
||
|
||
if not task.task_id:
|
||
return Response({'error': '任务ID不存在'}, status=status.HTTP_400_BAD_REQUEST)
|
||
|
||
service = AliyunTingwuService()
|
||
try:
|
||
result = service.get_task_info(task.task_id)
|
||
|
||
# 兼容处理响应结构 {"Data": {"TaskStatus": "...", "Result": ...}}
|
||
# 有些情况下 SDK 返回的是 JSON 字符串,需要二次解析
|
||
if isinstance(result, str):
|
||
import json
|
||
try:
|
||
result = json.loads(result)
|
||
except:
|
||
pass
|
||
|
||
if isinstance(result, dict):
|
||
data_obj = result.get('Data', result)
|
||
else:
|
||
data_obj = result
|
||
if not isinstance(data_obj, dict):
|
||
# 如果 Data 不是字典,可能它本身就是字符串,或者 result 结构更平铺
|
||
data_obj = result
|
||
|
||
# 防御性编程:确保 data_obj 是字典
|
||
if not isinstance(data_obj, dict):
|
||
logger.error(f"Unexpected response format: {type(data_obj)} - {data_obj}")
|
||
return Response({'error': f"Unexpected response format: {type(data_obj)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|
||
|
||
# 调用 Service 进行解析和更新
|
||
service.parse_and_update_task(task, result)
|
||
|
||
# 如果任务成功但 AI 总结仍为空 (可能之前解析没触发,或者大模型调用失败)
|
||
# 再次尝试强制触发 summarize_task (如果原始数据存在)
|
||
# 注意:service.parse_and_update_task 内部已经尝试异步触发,这里作为补救措施
|
||
if task.status == TranscriptionTask.Status.SUCCEEDED and not task.summary:
|
||
if task.summary_data or task.auto_chapters_data:
|
||
try:
|
||
# 先设置状态为 "AI总结生成当中..."
|
||
task.summary = "AI总结生成当中..."
|
||
task.save(update_fields=['summary'])
|
||
|
||
# 异步触发总结生成
|
||
import threading
|
||
from .bailian_service import BailianService
|
||
|
||
def async_summarize(task_id):
|
||
try:
|
||
# 重新获取 task 对象以避免线程问题
|
||
from .models import TranscriptionTask
|
||
task_obj = TranscriptionTask.objects.get(id=task_id)
|
||
bailian_service = BailianService()
|
||
bailian_service.summarize_task(task_obj)
|
||
except Exception as e:
|
||
logger.error(f"Async summary generation failed: {e}")
|
||
|
||
threading.Thread(target=async_summarize, args=(task.id,)).start()
|
||
|
||
except Exception as e:
|
||
logger.error(f"Force trigger AI summarization failed: {e}")
|
||
|
||
# 重新获取 task 以包含更新后的关联字段
|
||
task.refresh_from_db()
|
||
|
||
serializer = self.get_serializer(task)
|
||
return Response(serializer.data)
|
||
|
||
except Exception as e:
|
||
logger.error(f"刷新任务状态失败: {e}")
|
||
return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
|