Files
Scoring-System/backend/ai_services/services.py

421 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import logging
import time
import uuid
import oss2
from aliyunsdkcore.client import AcsClient
from aliyunsdkcore.acs_exception.exceptions import ClientException, ServerException
# Try to import the Tingwu request classes for API version 2023-09-30.
# If the installed SDK does not ship this version, bind both names to None
# and log a warning, so the problem is visible at import time instead of
# surfacing later as an unexplained NameError when a request is built.
try:
    from aliyunsdktingwu.request.v20230930 import CreateTaskRequest, GetTaskInfoRequest
except ImportError:
    CreateTaskRequest = None
    GetTaskInfoRequest = None
    logging.getLogger(__name__).warning(
        "aliyunsdktingwu v20230930 request classes are unavailable; "
        "Tingwu transcription calls will fail until the SDK is installed."
    )
from django.conf import settings
logger = logging.getLogger(__name__)
from .models import TranscriptionTask, AIEvaluation, AIEvaluationTemplate
class AliyunTingwuService:
    """Wrapper around Aliyun OSS uploads and the Tingwu transcription API.

    The service uploads media to OSS, creates and polls Tingwu offline
    transcription tasks, stores the parsed results on a ``TranscriptionTask``
    and kicks off the follow-up AI summary and AI evaluations.
    """

    # Endpoint used for every Tingwu request. The original author's notes say
    # Tingwu is served mainly from Beijing, even when OSS lives in Shanghai.
    TINGWU_ENDPOINT = "tingwu.cn-beijing.aliyuncs.com"

    # Seconds to wait when downloading a result file referenced by URL.
    DOWNLOAD_TIMEOUT = 30

    # Remote status values that are all treated as "task finished successfully".
    _SUCCESS_STATUSES = ('COMPLETE', 'COMPLETED', 'SUCCEEDED')

    def __init__(self):
        """Read the Aliyun configuration from Django settings and build clients.

        Missing configuration never raises here: the affected client attribute
        (``self.bucket`` or ``self.client``) is set to ``None`` and a warning is
        logged; the methods that need it raise when they are called.
        """
        self.access_key_id = settings.ALIYUN_ACCESS_KEY_ID
        self.access_key_secret = settings.ALIYUN_ACCESS_KEY_SECRET
        self.oss_bucket_name = settings.ALIYUN_OSS_BUCKET_NAME
        self.oss_endpoint = settings.ALIYUN_OSS_ENDPOINT
        self.tingwu_app_key = settings.ALIYUN_TINGWU_APP_KEY
        # Region passed to the ACS client; the Tingwu endpoint itself is
        # pinned separately via TINGWU_ENDPOINT.
        self.region_id = "cn-shanghai"

        has_credentials = bool(self.access_key_id and self.access_key_secret)

        # OSS bucket: requires credentials, an endpoint and a bucket name.
        if has_credentials and self.oss_endpoint and self.oss_bucket_name:
            auth = oss2.Auth(self.access_key_id, self.access_key_secret)
            self.bucket = oss2.Bucket(auth, self.oss_endpoint, self.oss_bucket_name)
        else:
            self.bucket = None
            logger.warning("Aliyun OSS configuration missing.")

        # Tingwu client.
        if has_credentials:
            self.client = AcsClient(
                self.access_key_id,
                self.access_key_secret,
                self.region_id,
            )
            # Register the Tingwu endpoint explicitly to avoid the SDK's
            # EndpointResolvingError. Failure is not fatal because each
            # request also sets the endpoint itself.
            try:
                self.client.add_endpoint(self.region_id, "tingwu", self.TINGWU_ENDPOINT)
            except Exception as e:
                logger.warning(f"Failed to add endpoint: {e}")
        else:
            self.client = None
            logger.warning("Aliyun AccessKey configuration missing.")

    def upload_to_oss(self, file_obj, file_name, day=7):
        """Upload a file to OSS and return a signed GET URL for it.

        :param file_obj: open file object or bytes to upload
        :param file_name: object key to store the upload under
        :param day: lifetime of the signed URL in days (default 7, i.e.
            3600 * 24 * 7 = 604800 seconds)
        :return: the signed URL
        :raises Exception: if the OSS bucket was not initialised, or whatever
            the OSS client raises on failure (logged, then re-raised)
        """
        if not self.bucket:
            raise Exception("OSS Client not initialized")
        try:
            self.bucket.put_object(file_name, file_obj)
            return self.bucket.sign_url('GET', file_name, 3600 * 24 * day)
        except Exception as e:
            logger.error(f"OSS Upload failed: {e}")
            raise

    def create_transcription_task(self, file_url, language="cn"):
        """Create a Tingwu offline transcription task for ``file_url``.

        :param file_url: URL of the media file to transcribe
        :param language: value sent as ``SourceLanguage``
        :return: the decoded JSON response of the CreateTask call
        :raises Exception: if the Tingwu client was not initialised
        :raises ClientException, ServerException: on SDK failures (logged, then
            re-raised)
        """
        if not self.client:
            raise Exception("Tingwu Client not initialized")

        request = CreateTaskRequest.CreateTaskRequest()
        # "type" is sent as a query parameter; 'offline' selects an offline task.
        request.add_query_param('type', 'offline')

        # AppKey, Input and Parameters travel in the JSON body, not the query.
        body = {
            "AppKey": self.tingwu_app_key,
            "Input": {
                "FileUrl": file_url,
                "SourceLanguage": language,
                "TaskKey": str(uuid.uuid4()),
            },
            "Parameters": {
                "Transcoding": {
                    "TargetAudioFormat": "mp3",
                },
                "Transcription": {
                    "DiarizationEnabled": True,
                    "ChannelId": 0,
                },
                "TranscriptionEnabled": True,
                "AutoChaptersEnabled": True,
                "SummarizationEnabled": True,
                "Summarization": {
                    "Types": ["Paragraph", "Conversational", "QuestionsAnswering", "MindMap"],
                },
            },
        }
        request.set_content(json.dumps(body))
        request.add_header('Content-Type', 'application/json')
        # Pin the endpoint on the request to avoid SDK.EndpointResolvingError.
        request.set_endpoint(self.TINGWU_ENDPOINT)
        request.set_method('PUT')

        try:
            response = self.client.do_action_with_exception(request)
            return json.loads(response)
        except (ClientException, ServerException) as e:
            logger.error(f"Tingwu CreateTask failed: {e}")
            raise

    def get_task_info(self, task_id):
        """Query the status and result of a Tingwu task.

        :param task_id: Tingwu task identifier
        :return: the decoded JSON response of the GetTaskInfo call
        :raises Exception: if the Tingwu client was not initialised
        :raises ClientException, ServerException: on SDK failures (logged, then
            re-raised)
        """
        if not self.client:
            raise Exception("Tingwu Client not initialized")

        request = GetTaskInfoRequest.GetTaskInfoRequest()
        request.set_TaskId(task_id)
        # Same explicit endpoint as create_transcription_task, so polling does
        # not depend on the add_endpoint() call in __init__ having succeeded.
        request.set_endpoint(self.TINGWU_ENDPOINT)

        try:
            response = self.client.do_action_with_exception(request)
            return json.loads(response)
        except (ClientException, ServerException) as e:
            logger.error(f"Tingwu GetTaskInfo failed: {e}")
            raise

    @staticmethod
    def _lookup_field(source, key):
        """Return ``source[key]``, falling back to ``source['Result'][key]``; else None."""
        if not isinstance(source, dict):
            return None
        if key in source:
            return source[key]
        nested = source.get('Result')
        if isinstance(nested, dict) and key in nested:
            return nested[key]
        return None

    @classmethod
    def _fetch_json(cls, url, failure_message):
        """Download ``url`` and return its decoded JSON, or None on any failure.

        Failures (import error, network error, timeout, non-200 status,
        invalid JSON) are logged with ``failure_message`` and never raised,
        because result downloads are best-effort.
        """
        try:
            import requests
            response = requests.get(url, timeout=cls.DOWNLOAD_TIMEOUT)
            if response.status_code != 200:
                logger.error(f"{failure_message}: HTTP {response.status_code}")
                return None
            return response.json()
        except Exception as e:
            logger.error(f"{failure_message}: {e}")
            return None

    @staticmethod
    def _extract_transcript_text(transcription_data):
        """Build the plain-text transcript from a transcription payload.

        Accepts either ``{"Transcription": {"Paragraphs": [...]}}`` or
        ``{"Paragraphs": [...]}``. Each paragraph contributes one line, taken
        from its ``Words[*].Text`` values or, failing that, its own ``Text``.
        Returns None when no line could be extracted; malformed entries are
        skipped instead of raising.
        """
        if not isinstance(transcription_data, dict):
            return None
        content = transcription_data
        inner = content.get('Transcription')
        if isinstance(inner, dict):
            content = inner

        paragraphs = content.get('Paragraphs', [])
        if not isinstance(paragraphs, list):
            return None

        lines = []
        for paragraph in paragraphs:
            if not isinstance(paragraph, dict):
                continue
            words = paragraph.get('Words', [])
            if isinstance(words, list) and words:
                lines.append("".join(
                    str(word.get('Text', '')) for word in words if isinstance(word, dict)
                ))
            elif 'Text' in paragraph:
                # Older/alternative structure: text stored on the paragraph.
                lines.append(str(paragraph['Text']))
        return "\n".join(lines) if lines else None

    def _start_async_summary(self, task):
        """Mark ``task`` as "summary in progress" and summarise it in a thread.

        Any failure while setting this up is logged and swallowed so that it
        cannot break result parsing.
        """
        try:
            import threading
            # Imported before the placeholder is written, so a failing import
            # does not leave the task stuck on the "generating" text.
            from .bailian_service import BailianService

            task.summary = "AI总结生成当中..."
            task.save(update_fields=['summary'])

            def summarize(task_pk):
                try:
                    # Re-fetch the task inside the thread instead of sharing
                    # the caller's instance.
                    from .models import TranscriptionTask
                    fresh_task = TranscriptionTask.objects.get(id=task_pk)
                    BailianService().summarize_task(fresh_task)
                except Exception as e:
                    logger.error(f"Async summary generation failed in service: {e}")

            threading.Thread(target=summarize, args=(task.id,)).start()
            logger.info(f"Triggered async summary generation for task {task.id}")
        except Exception as e:
            logger.error(f"Failed to trigger AI summarization: {e}")

    def parse_and_update_task(self, task, result):
        """Parse a Tingwu task result and store it on ``task``.

        :param task: ``TranscriptionTask`` instance to update
        :param result: full JSON returned by ``get_task_info`` (or only its
            ``Data`` part)

        Behaviour by remote status:
        * ``FAILED``: status and error message are saved, nothing else happens.
        * ``COMPLETE`` / ``COMPLETED`` / ``SUCCEEDED``: transcription, summary
          and chapter data are stored (downloading them when only a URL is
          given), the task is saved, the async summary is started when summary
          or chapter data exists, and AI evaluations are started when the task
          has just become ``SUCCEEDED`` and has a transcript.
        * anything else: the task is still running and is left untouched.
        """
        # Remember the previous status to detect the first successful completion.
        previous_status = task.status

        # 1. Locate the Data object.
        data_obj = result.get('Data', result) if isinstance(result, dict) else result
        if not isinstance(data_obj, dict):
            logger.error(f"Unexpected data format: {type(data_obj)}")
            return

        # 2. Update the status.
        task_status = data_obj.get('TaskStatus') or data_obj.get('Status')
        if task_status == 'FAILED':
            task.status = 'FAILED'
            task.error_message = data_obj.get('TaskStatusText', data_obj.get('Message', 'Unknown error'))
            task.save()
            return
        if task_status not in self._SUCCESS_STATUSES:
            # Still processing: leave the task untouched.
            return
        task.status = 'SUCCEEDED'

        # 3. Parse the result. When Result is missing or empty, fall back to
        #    Data itself, since some fields may sit directly on that level.
        task_result = data_obj.get('Result', {})
        if not task_result:
            task_result = data_obj

        def pick(key, default):
            return (
                self._lookup_field(task_result, key)
                or self._lookup_field(data_obj, key)
                or default
            )

        # --- A. Transcription ---
        transcription_data = pick('Transcription', {})
        if isinstance(transcription_data, str) and transcription_data.startswith('http'):
            downloaded = self._fetch_json(transcription_data, "Download transcription failed")
            transcription_data = downloaded if downloaded is not None else {}
        elif isinstance(transcription_data, dict) and 'TranscriptionUrl' in transcription_data:
            downloaded = self._fetch_json(
                transcription_data['TranscriptionUrl'],
                "Download transcription url failed",
            )
            # On failure keep the original dict, which still carries the URL.
            if downloaded is not None:
                transcription_data = downloaded
        task.transcription_data = transcription_data

        transcript_text = self._extract_transcript_text(transcription_data)
        if transcript_text is not None:
            task.transcription = transcript_text

        # --- B. Summarization (raw data only; no text is assembled here) ---
        summarization = pick('Summarization', {})
        if isinstance(summarization, str) and summarization.startswith('http'):
            downloaded = self._fetch_json(summarization, "Download summarization failed")
            summarization = downloaded if downloaded is not None else {}
        task.summary_data = summarization

        # --- C. Auto chapters ---
        auto_chapters = pick('AutoChapters', [])
        if isinstance(auto_chapters, str) and auto_chapters.startswith('http'):
            downloaded = self._fetch_json(auto_chapters, "Download auto chapters failed")
            auto_chapters = downloaded if downloaded is not None else []
        task.auto_chapters_data = auto_chapters

        # Persist the raw data before any follow-up work starts.
        task.save()

        # Generate the AI summary when there is something to summarise.
        # NOTE(review): this runs on every successful parse, not only the
        # first one - confirm callers do not re-parse finished tasks.
        if task.summary_data or task.auto_chapters_data:
            self._start_async_summary(task)

        # 4. Trigger AI evaluations the first time the task succeeds.
        if previous_status != 'SUCCEEDED' and task.status == 'SUCCEEDED' and task.transcription:
            import threading
            # Pass the id, not the instance, so the thread loads its own copy.
            threading.Thread(target=self.trigger_ai_evaluations, args=(task.id,)).start()

    def _template_applies(self, template, task):
        """Decide whether ``template`` should be evaluated automatically for ``task``."""
        task_competition = None
        if task.project and task.project.competition:
            task_competition = task.project.competition

        if template.score_dimension:
            # Template is tied to a score dimension: only evaluate tasks whose
            # competition is the one that dimension belongs to.
            if not task_competition:
                logger.info(f"Task {task.id} has no associated competition. Skipping template '{template.name}'.")
                return False
            from competition.models import ScoreDimension
            related_competition_ids = ScoreDimension.objects.filter(
                id=template.score_dimension.id
            ).values_list('competition_id', flat=True)
            if task_competition.id in related_competition_ids:
                logger.info(f"Template '{template.name}' is linked to score_dimension, task's competition matches.")
                return True
            logger.info(f"Template '{template.name}' is linked to score_dimension, but task's competition does not match. Skipping.")
            return False

        # No score dimension: only the default template applies to every competition.
        if template.is_default:
            logger.info(f"Template '{template.name}' is default template, evaluating all competitions.")
            return True
        logger.info(f"Template '{template.name}' is not linked to score_dimension and is not default. Skipping.")
        return False

    def trigger_ai_evaluations(self, task_id):
        """Create and run AI evaluations for a task from the active templates.

        Rules:
        1. A template linked to a score dimension (``score_dimension``) is only
           applied to tasks whose competition matches that dimension.
        2. A template without a score dimension is applied to every
           competition if it is the default template (``is_default=True``);
           otherwise it is skipped.

        Templates that already have an evaluation for the task are skipped.

        :param task_id: primary key of the ``TranscriptionTask``; a task
            instance is also tolerated for backward compatibility
        """
        # Tolerate callers that pass the task object instead of its id.
        task_pk = getattr(task_id, 'id', task_id)
        try:
            # Re-fetch inside the thread and preload project/competition so no
            # lazy loading is needed later.
            from .models import TranscriptionTask
            task = TranscriptionTask.objects.select_related('project', 'project__competition').get(id=task_pk)
        except Exception as e:
            if not hasattr(task_id, 'id'):
                logger.error(f"Failed to retrieve task {task_id}: {e}")
                return
            # Fall back to the instance that was passed in.
            task = task_id

        active_templates = AIEvaluationTemplate.objects.filter(is_active=True)
        if not active_templates.exists():
            logger.info("No active AI evaluation templates found.")
            return

        from .bailian_service import BailianService
        service = BailianService()

        for template in active_templates:
            # Avoid creating a duplicate evaluation for the same task/template.
            if AIEvaluation.objects.filter(task=task, template=template).exists():
                logger.info(f"Evaluation for task {task.id} and template {template.name} already exists.")
                continue

            if not self._template_applies(template, task):
                continue

            evaluation = AIEvaluation.objects.create(
                task=task,
                template=template,
                model_selection=template.model_selection,
                prompt=template.prompt,
                status=AIEvaluation.Status.PENDING,
            )
            try:
                service.evaluate_task(evaluation)
                logger.info(f"Triggered evaluation {evaluation.id} for template {template.name}")
            except Exception as e:
                logger.error(f"Failed to trigger evaluation {evaluation.id}: {e}")