362 lines
17 KiB
Python
362 lines
17 KiB
Python
import logging
|
|
import uuid
|
|
from rest_framework import viewsets, status
|
|
from rest_framework.decorators import action, api_view, permission_classes, parser_classes
|
|
from rest_framework.response import Response
|
|
from rest_framework.parsers import MultiPartParser, FormParser, JSONParser
|
|
from rest_framework.permissions import AllowAny
|
|
from django.conf import settings
|
|
from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiTypes
|
|
from .models import TranscriptionTask
|
|
from .serializers import TranscriptionTaskSerializer, TranscriptionUploadSerializer
|
|
from .services import AliyunTingwuService
|
|
|
|
# Module-level logger named after this module; used by the Tingwu callback
# view and the transcription viewset for request tracing and error reporting.
logger = logging.getLogger(__name__)
|
|
|
|
@api_view(['POST'])
@permission_classes([AllowAny])
def tingwu_callback(request):
    """
    Handle callback messages pushed by Aliyun Tingwu.

    * A connectivity-test message (``{"Code": "0", "Data": {"Test": ...}, ...}``)
      is acknowledged immediately.
    * A task notification is matched to a ``TranscriptionTask`` by ``TaskId``.
      ``TaskId`` and the status are read from the top level of the payload,
      falling back to a nested ``Data`` dict (the same envelope shape the
      viewset accepts from Tingwu API responses). The status key may be
      ``Status`` or ``TaskStatus``. A ``FAILED`` status marks the task failed;
      a completion status is only logged, and results are fetched later via
      the viewset's ``refresh_status`` action.

    Always responds HTTP 200 with ``{"message": "success"}``. Errors raised
    while processing a notification are logged, not returned to the sender.

    NOTE(review): this endpoint is ``AllowAny`` and performs no signature or
    origin check, so any client that knows a TaskId can mark that task
    FAILED. Consider verifying the callback origin.
    """
    data = request.data
    logger.info(f"收到听悟回调: {data}")

    if not isinstance(data, dict):
        # Nothing actionable (e.g. a JSON list body). Previously ``data.get``
        # raised AttributeError outside any try block and produced a 500.
        logger.warning(f"Unrecognized Tingwu callback payload type: {type(data)}")
        return Response({'message': 'success'}, status=status.HTTP_200_OK)

    envelope = data.get('Data')
    nested = envelope if isinstance(envelope, dict) else {}

    # 1. Connectivity test, e.g.
    #    {"Code": "0", "Data": {"Test": "..."}, "Message": "success", "RequestId": "..."}
    #    Only a dict ``Data`` qualifies; the old check did a substring test on
    #    strings and raised TypeError when ``Data`` was None.
    if 'Test' in nested:
        logger.info("收到听悟连通性测试请求")
        return Response({'message': 'success'}, status=status.HTTP_200_OK)

    # 2. Task status notification.
    task_id = data.get('TaskId') or nested.get('TaskId')
    task_status = (
        data.get('Status')
        or data.get('TaskStatus')
        or nested.get('TaskStatus')
        or nested.get('Status')
    )

    if task_id:
        try:
            task = TranscriptionTask.objects.filter(task_id=task_id).first()
            if task is None:
                logger.warning(f"回调收到未知任务ID: {task_id}")
            elif task_status in ('COMPLETE', 'COMPLETED', 'SUCCEEDED'):
                # Results are pulled by refresh_status; doing it here risks
                # exceeding the callback sender's timeout.
                logger.info(f"任务 {task_id} 完成,等待下一次查询刷新")
            elif task_status == 'FAILED':
                task.status = TranscriptionTask.Status.FAILED
                task.error_message = (
                    data.get('StatusText')
                    or nested.get('TaskStatusText')
                    or 'Callback reported failure'
                )
                task.save()
        except Exception as e:
            # Best effort: acknowledge the callback anyway, but keep the traceback.
            logger.exception("处理回调异常: %s", e)

    return Response({'message': 'success'}, status=status.HTTP_200_OK)
|
|
|
|
class TranscriptionTaskViewSet(viewsets.ModelViewSet):
    """
    REST endpoints for transcription tasks backed by Aliyun Tingwu.

    Provides the inherited ModelViewSet CRUD over ``TranscriptionTask``, with
    ``create`` overridden to upload audio and submit a Tingwu task. It also
    adds a ``refresh_status`` detail action that polls Tingwu and copies the
    transcription and summary into the task row.
    """

    queryset = TranscriptionTask.objects.all()
    serializer_class = TranscriptionTaskSerializer
    parser_classes = (MultiPartParser, FormParser)

    # Timeout (seconds) for each result-file download (transcription,
    # summarization, auto chapters). Without one, ``requests.get`` can block
    # a worker indefinitely on a stalled connection.
    download_timeout = 30

    # Remote status strings treated as "finished successfully". A tuple (not a
    # set) so membership uses ``==`` and never hashes a value taken from the
    # response, which could be unhashable.
    _SUCCESS_STATUSES = ('COMPLETE', 'COMPLETED', 'SUCCEEDED')

    @extend_schema(
        request={
            'multipart/form-data': {
                'type': 'object',
                'properties': {
                    'file': {
                        'type': 'string',
                        'format': 'binary'
                    },
                    'file_url': {
                        'type': 'string',
                        'description': '音频文件的URL地址'
                    }
                }
            }
        },
        responses={201: TranscriptionTaskSerializer}
    )
    def create(self, request, *args, **kwargs):
        """
        Upload an audio file (or accept a URL) and create a Tingwu transcription task.

        Takes either a multipart ``file`` or a ``file_url`` form field. An uploaded
        file is stored via ``AliyunTingwuService.upload_to_oss`` under
        ``transcription/<uuid4><.ext>``, and the returned URL is used. A
        ``file_url`` is passed through unchanged.

        Responses:
            201: the serialized task, now PROCESSING with its Tingwu TaskId.
            400: neither ``file`` nor ``file_url`` was supplied.
            503: the service reports no OSS bucket or no API client.
            500: the upload failed, the Tingwu call failed, or no TaskId came
                back. In the last two cases the created row is marked FAILED.
        """
        file_obj = request.FILES.get('file')
        file_url = request.data.get('file_url')

        if not file_obj and not file_url:
            return Response({'error': '请提供文件或文件URL'}, status=status.HTTP_400_BAD_REQUEST)

        service = AliyunTingwuService()
        if not service.bucket or not service.client:
            return Response({'error': '阿里云服务未配置'}, status=status.HTTP_503_SERVICE_UNAVAILABLE)

        try:
            if file_obj:
                # 1. Upload to OSS under a random key, keeping the original
                #    extension. rpartition yields an empty suffix for names with
                #    no dot, instead of reusing the whole name as "extension".
                _, dot, extension = file_obj.name.rpartition('.')
                suffix = f".{extension}" if dot and extension else ""
                file_name = f"transcription/{uuid.uuid4()}{suffix}"
                oss_url = service.upload_to_oss(file_obj, file_name)
            else:
                # A caller-supplied URL is handed to Tingwu as-is.
                oss_url = file_url

            # 2. Persist the task first so a failed submission can be recorded on it.
            task_record = TranscriptionTask.objects.create(
                file_url=oss_url,
                status=TranscriptionTask.Status.PENDING,
            )

            # 3. Submit the Tingwu task.
            try:
                tingwu_response = service.create_transcription_task(oss_url)
                task_id = self._extract_task_id(tingwu_response)

                if not task_id:
                    task_record.status = TranscriptionTask.Status.FAILED
                    task_record.error_message = "未能获取 TaskId"
                    task_record.save()
                    return Response({'error': '未能获取 TaskId'}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

                task_record.task_id = task_id
                task_record.status = TranscriptionTask.Status.PROCESSING
                task_record.save()
            except Exception as e:
                task_record.status = TranscriptionTask.Status.FAILED
                task_record.error_message = str(e)
                task_record.save()
                logger.exception("创建听悟任务失败: %s", e)
                return Response({'error': f"创建听悟任务失败: {str(e)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

            serializer = self.get_serializer(task_record)
            return Response(serializer.data, status=status.HTTP_201_CREATED)

        except Exception as e:
            logger.exception("处理上传请求失败: %s", e)
            return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    # Per the drf-spectacular docs, @extend_schema must be the outer decorator:
    # DRF's @action assigns ``func.kwargs``, which discards the schema override
    # that extend_schema stores there if it was applied first.
    @extend_schema(
        parameters=[
            OpenApiParameter("id", OpenApiTypes.UUID, OpenApiParameter.PATH, description="Task ID"),
        ],
        responses={200: TranscriptionTaskSerializer}
    )
    @action(detail=True, methods=['get'])
    def refresh_status(self, request, pk=None):
        """
        Poll Tingwu for this task's status and store any finished results.

        Tasks already SUCCEEDED or FAILED are returned as stored. The exception
        is a SUCCEEDED task whose transcription is still empty (e.g. an earlier
        result download failed), which is refreshed again.

        Responses:
            200: the serialized task, updated when Tingwu reports completion or
                failure. Other remote states (e.g. still running) leave it as-is.
            400: the task has no Tingwu TaskId.
            500: unexpected response shape, or any error while refreshing.
        """
        task = self.get_object()

        retry_empty_success = (
            task.status == TranscriptionTask.Status.SUCCEEDED and not task.transcription
        )
        finished = task.status in (TranscriptionTask.Status.SUCCEEDED, TranscriptionTask.Status.FAILED)
        if finished and not retry_empty_success:
            return Response(self.get_serializer(task).data)

        if not task.task_id:
            return Response({'error': '任务ID不存在'}, status=status.HTTP_400_BAD_REQUEST)

        service = AliyunTingwuService()
        try:
            # The SDK sometimes returns the body as a JSON string rather than a dict.
            result = self._parse_json_if_str(service.get_task_info(task.task_id))
            if not isinstance(result, dict):
                logger.error(f"Unexpected response format: {type(result)} - {result}")
                return Response({'error': f"Unexpected response format: {type(result)}"}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

            # Usual shape: {"Data": {"TaskStatus": "...", "Result": {...}}}; fall
            # back to a flat response when "Data" is missing or not a dict.
            data_obj = result.get('Data')
            if not isinstance(data_obj, dict):
                data_obj = result

            task_status = data_obj.get('TaskStatus') or data_obj.get('Status')

            if task_status in self._SUCCESS_STATUSES:
                task.status = TranscriptionTask.Status.SUCCEEDED
                task_result = self._parse_json_if_str(data_obj.get('Result'))
                if not isinstance(task_result, dict):
                    # A null/non-dict Result previously crashed on ``.keys()``.
                    task_result = {}
                logger.info(f"Task result keys: {task_result.keys()}")

                self._apply_transcription(task, task_result.get('Transcription', {}))
                self._apply_summary(task, task_result)
                task.save()

            elif task_status == 'FAILED':
                task.status = TranscriptionTask.Status.FAILED
                task.error_message = data_obj.get('TaskStatusText', result.get('Message', 'Unknown error'))
                task.save()

            # Any other state (e.g. PENDING / RUNNING) leaves the task unchanged.

            return Response(self.get_serializer(task).data)

        except Exception as e:
            logger.exception("刷新任务状态失败: %s", e)
            return Response({'error': str(e)}, status=status.HTTP_500_INTERNAL_SERVER_ERROR)

    @staticmethod
    def _parse_json_if_str(value):
        """
        Decode *value* as JSON if it is a string; otherwise return it unchanged.

        A string that is not valid JSON is returned unchanged.
        """
        if not isinstance(value, str):
            return value
        import json
        try:
            return json.loads(value)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            return value

    @classmethod
    def _extract_task_id(cls, response):
        """
        Return the TaskId from a Tingwu create-task response, or None if absent.

        Accepts ``{"Data": {"TaskId": ...}}`` or a flat ``{"TaskId": ...}``,
        given either as a dict or as a JSON string.
        """
        response = cls._parse_json_if_str(response)
        if not isinstance(response, dict):
            return None
        envelope = response.get('Data')
        if isinstance(envelope, dict):
            return envelope.get('TaskId')
        return response.get('TaskId')

    def _download_json(self, url, label):
        """
        Download *url* and return its decoded JSON body, or None on failure.

        *label* names the artifact in log messages. Non-200 responses, network
        errors/timeouts and undecodable bodies are logged and reported as None
        rather than raised, so one missing artifact does not abort the refresh.
        """
        logger.info(f"Downloading {label} from {url}")
        try:
            import requests
            resp = requests.get(url, timeout=self.download_timeout)
            if resp.status_code != 200:
                logger.warning(f"Failed to download {label}: {resp.status_code}")
                return None
            return resp.json()
        except Exception as e:
            logger.error(f"Error downloading {label}: {e}")
            return None

    def _apply_transcription(self, task, transcription_data):
        """
        Fill ``task.transcription`` from the ``Result.Transcription`` entry.

        *transcription_data* may be a URL string, a dict carrying a
        ``TranscriptionUrl``, or an inline dict. A successfully downloaded body
        is also stored in ``task.transcription_data``. ``task.transcription``
        is left untouched when no text can be extracted.
        """
        logger.info(f"Raw transcription data type: {type(transcription_data)}")

        url = None
        if isinstance(transcription_data, str) and transcription_data.startswith('http'):
            url = transcription_data
        elif isinstance(transcription_data, dict) and 'TranscriptionUrl' in transcription_data:
            url = transcription_data['TranscriptionUrl']

        if url:
            downloaded = self._download_json(url, 'transcription')
            if downloaded is not None:
                logger.info(f"Downloaded transcription keys: {downloaded.keys() if isinstance(downloaded, dict) else 'Not a dict'}")
                task.transcription_data = downloaded
                transcription_data = downloaded

        if not isinstance(transcription_data, dict):
            return

        # The body may be {"Sentences": ...} or wrapped, e.g.
        # {"TaskId": "...", "Transcription": {"Paragraphs": ...}}; unwrap if so.
        content_source = transcription_data
        nested = content_source.get('Transcription')
        if isinstance(nested, dict):
            content_source = nested
            logger.info(f"Drilled down to nested 'Transcription' key. Keys: {content_source.keys()}")

        text = self._extract_transcript_text(content_source)
        if text is not None:
            task.transcription = text

    @staticmethod
    def _extract_transcript_text(content_source):
        """
        Build transcript text from a transcription dict, or None if nothing usable.

        Prefers a non-empty top-level ``Sentences`` list (texts joined by
        spaces). Otherwise reads ``Paragraphs`` (a list, or a dict wrapping a
        ``Paragraphs`` list) and emits one line per paragraph ``Text``, per
        nested sentence ``Text``, or per paragraph assembled from ``Words``.
        Non-dict items and non-string texts are skipped.
        """
        def text_of(item):
            # The item's "Text" value if it is a string, else None.
            value = item.get('Text') if isinstance(item, dict) else None
            return value if isinstance(value, str) else None

        sentences = content_source.get('Sentences')
        if isinstance(sentences, list) and sentences:
            return " ".join(text_of(s) or '' for s in sentences)

        raw_paragraphs = content_source.get('Paragraphs')
        if not raw_paragraphs:
            logger.warning(f"Could not find Sentences or Paragraphs in content source. Keys: {content_source.keys()}")
            return None

        # Either {"Paragraphs": [...]} nested once more, or the list itself.
        if isinstance(raw_paragraphs, dict):
            paragraphs = raw_paragraphs.get('Paragraphs')
        else:
            paragraphs = raw_paragraphs
        if not isinstance(paragraphs, list) or not paragraphs:
            logger.warning(f"Paragraphs found but failed to extract list. Type: {type(raw_paragraphs)}")
            return None

        lines = []
        for paragraph in paragraphs:
            if not isinstance(paragraph, dict):
                continue
            paragraph_text = text_of(paragraph)
            if paragraph_text is not None:
                lines.append(paragraph_text)
            elif isinstance(paragraph.get('Sentences'), list):
                lines.extend(t for t in map(text_of, paragraph['Sentences']) if t is not None)
            elif isinstance(paragraph.get('Words'), list):
                # NOTE(review): word texts are concatenated without separators,
                # assuming each word carries its own spacing (fine for Chinese
                # output) — confirm for other languages.
                lines.append("".join(t for t in map(text_of, paragraph['Words']) if t is not None))

        logger.info(f"Extracted {len(lines)} paragraphs")
        return "\n".join(lines)

    def _apply_summary(self, task, task_result):
        """
        Fill ``task.summary`` from a Tingwu ``Result`` dict.

        Uses the first non-empty string among the summarization's ``Text``,
        ``Headline`` and ``ParagraphSummary`` fields. Otherwise it joins the
        ``Headline``/``Summary`` strings of each chapter, taken from
        ``Chapters`` or, when that is empty, from ``AutoChapters``. Downloaded
        bodies are stored in ``task.summary_data`` / ``task.auto_chapters_data``.
        """
        summarization = task_result.get('Summarization', {})
        if isinstance(summarization, str) and summarization.startswith('http'):
            summarization = self._download_json(summarization, 'summarization')
            if summarization is not None:
                task.summary_data = summarization

        # A downloaded file may wrap its content in a "Summarization" key, like
        # the nested "Transcription" key handled for transcripts.
        if isinstance(summarization, dict) and isinstance(summarization.get('Summarization'), dict):
            summarization = summarization['Summarization']
        if not isinstance(summarization, dict):
            summarization = {}

        # NOTE(review): 'ParagraphSummary' is the full-text summary field name
        # assumed from Tingwu's summarization result format — confirm.
        for key in ('Text', 'Headline', 'ParagraphSummary'):
            value = summarization.get(key)
            if isinstance(value, str) and value:
                task.summary = value
                return

        chapters = task_result.get('Chapters')
        auto_chapters = task_result.get('AutoChapters', {})
        if isinstance(auto_chapters, str) and auto_chapters.startswith('http'):
            downloaded = self._download_json(auto_chapters, 'auto chapters')
            if downloaded is not None:
                task.auto_chapters_data = downloaded
                auto_chapters = downloaded

        if not chapters:
            # Previously the AutoChapters data was downloaded and stored but
            # never used; fall back to it, unwrapping {"AutoChapters": [...]}.
            if isinstance(auto_chapters, dict):
                auto_chapters = auto_chapters.get('AutoChapters')
            chapters = auto_chapters
        if not isinstance(chapters, list):
            chapters = []

        summary_parts = []
        for chapter in chapters:
            if not isinstance(chapter, dict):
                continue
            for key in ('Headline', 'Summary'):
                value = chapter.get(key)
                if isinstance(value, str):
                    summary_parts.append(value)
        task.summary = "\n".join(summary_parts)
|