Merge branch 'master-jdk21-ai' into master-jdk21-ai-write

This commit is contained in:
xiaoxin 2024-07-03 09:18:18 +08:00
commit aa4c1cb268
95 changed files with 5707 additions and 1654 deletions

View File

@ -29,13 +29,17 @@ public interface ErrorCodeConstants {
// ========== API 聊天消息 1-040-004-000 ==========
ErrorCode AI_CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
ErrorCode AI_CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "Stream 对话异常!");
ErrorCode CHAT_MESSAGE_NOT_EXIST = new ErrorCode(1_040_004_000, "消息不存在!");
ErrorCode CHAT_STREAM_ERROR = new ErrorCode(1_040_004_001, "Stream 对话异常!");
// ========== API 绘画 1-040-005-000 ==========
ErrorCode AI_IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
ErrorCode AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
ErrorCode AI_IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
ErrorCode IMAGE_NOT_EXISTS = new ErrorCode(1_022_005_000, "图片不存在!");
ErrorCode IMAGE_MIDJOURNEY_SUBMIT_FAIL = new ErrorCode(1_022_005_001, "Midjourney 提交失败!原因:{}");
ErrorCode IMAGE_CUSTOM_ID_NOT_EXISTS = new ErrorCode(1_022_005_002, "Midjourney 按钮 customId 不存在! {}");
ErrorCode IMAGE_FAIL = new ErrorCode(1_022_005_002, "图片绘画失败! {}");
// ========== API 音乐 1-040-006-000 ==========
ErrorCode MUSIC_NOT_EXISTS = new ErrorCode(1_022_006_000, "音乐不存在!");
}

View File

@ -3,6 +3,7 @@ package cn.iocoder.yudao.module.ai.enums.model;
import lombok.AllArgsConstructor;
import lombok.Getter;
// TODO @芋艿可以考虑清理掉
/**
* ai 模型
*

View File

@ -1,8 +1,11 @@
package cn.iocoder.yudao.module.ai.enums.music;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 音乐状态的枚举
*
@ -10,10 +13,10 @@ import lombok.Getter;
*/
@AllArgsConstructor
@Getter
public enum AiMusicGenerateModeEnum {
public enum AiMusicGenerateModeEnum implements IntArrayValuable {
LYRIC(1, "歌词模式"),
DESCRIPTION(2, "描述模式");
DESCRIPTION(1, "描述模式"),
LYRIC(2, "歌词模式");
/**
* 模式
@ -24,4 +27,11 @@ public enum AiMusicGenerateModeEnum {
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiMusicGenerateModeEnum::getMode).toArray();
@Override
public int[] array() {
return ARRAYS;
}
}

View File

@ -1,8 +1,11 @@
package cn.iocoder.yudao.module.ai.enums.music;
import cn.iocoder.yudao.framework.common.core.IntArrayValuable;
import lombok.AllArgsConstructor;
import lombok.Getter;
import java.util.Arrays;
/**
* AI 音乐状态的枚举
*
@ -10,10 +13,11 @@ import lombok.Getter;
*/
@AllArgsConstructor
@Getter
public enum AiMusicStatusEnum {
public enum AiMusicStatusEnum implements IntArrayValuable {
IN_PROGRESS(10, "进行中"),
SUCCESS(20, "已完成");
SUCCESS(20, "已完成"),
FAIL(30, "已失败");
/**
* 状态
@ -25,4 +29,11 @@ public enum AiMusicStatusEnum {
*/
private final String name;
public static final int[] ARRAYS = Arrays.stream(values()).mapToInt(AiMusicStatusEnum::getStatus).toArray();
@Override
public int[] array() {
return ARRAYS;
}
}

View File

@ -9,7 +9,7 @@ import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageRespVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdatePublicStatusReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@ -25,6 +25,8 @@ import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.validation.annotation.Validated;
import org.springframework.web.bind.annotation.*;
import java.util.List;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.security.core.util.SecurityFrameworkUtils.getLoginUserId;
@ -37,15 +39,16 @@ public class AiImageController {
@Resource
private AiImageService imageService;
@Operation(summary = "获取【我的】绘图分页")
@GetMapping("/my-page")
@Operation(summary = "获取【我的】绘图分页")
public CommonResult<PageResult<AiImageRespVO>> getImagePageMy(@Validated PageParam pageReqVO) {
PageResult<AiImageDO> pageResult = imageService.getImagePageMy(getLoginUserId(), pageReqVO);
return success(BeanUtils.toBean(pageResult, AiImageRespVO.class));
}
@Operation(summary = "获取【我的】绘图记录")
@GetMapping("/get-my")
@Operation(summary = "获取【我的】绘图记录")
@Parameter(name = "id", required = true, description = "绘画编号", example = "1024")
public CommonResult<AiImageRespVO> getImageMy(@RequestParam("id") Long id) {
AiImageDO image = imageService.getImage(id);
if (image == null || ObjUtil.notEqual(getLoginUserId(), image.getUserId())) {
@ -54,6 +57,15 @@ public class AiImageController {
return success(BeanUtils.toBean(image, AiImageRespVO.class));
}
@GetMapping("/my-list-by-ids")
@Operation(summary = "获取【我的】绘图记录列表")
@Parameter(name = "ids", required = true, description = "绘画编号数组", example = "1024,2048")
public CommonResult<List<AiImageRespVO>> getImageListMyByIds(@RequestParam("ids") List<Long> ids) {
List<AiImageDO> imageList = imageService.getImageList(ids);
imageList.removeIf(item -> !ObjUtil.equal(getLoginUserId(), item.getUserId()));
return success(BeanUtils.toBean(imageList, AiImageRespVO.class));
}
@Operation(summary = "生成图片")
@PostMapping("/draw")
public CommonResult<Long> drawImage(@Validated @RequestBody AiImageDrawReqVO drawReqVO) {
@ -102,11 +114,11 @@ public class AiImageController {
return success(BeanUtils.toBean(pageResult, AiImageRespVO.class));
}
@PutMapping("/update-public-status")
@Operation(summary = "更新绘画发布状态")
@PutMapping("/update")
@Operation(summary = "更新绘画")
@PreAuthorize("@ss.hasPermission('ai:image:update')")
public CommonResult<Boolean> updateImagePublicStatus(@Valid @RequestBody AiImageUpdatePublicStatusReqVO updateReqVO) {
imageService.updateImagePublicStatus(updateReqVO);
public CommonResult<Boolean> updateImage(@Valid @RequestBody AiImageUpdateReqVO updateReqVO) {
imageService.updateImage(updateReqVO);
return success(true);
}

View File

@ -20,7 +20,7 @@ public class AiImagePageReqVO extends PageParam {
@Schema(description = "用户编号", example = "28987")
private Long userId;
@Schema(description = "平台")
@Schema(description = "平台", example = "OpenAI")
private String platform;
@Schema(description = "绘画状态", example = "1")

View File

@ -4,15 +4,15 @@ import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
@Schema(description = "管理后台 - AI 绘画修改发布状态 Request VO")
@Schema(description = "管理后台 - AI 绘画修改 Request VO")
@Data
public class AiImageUpdatePublicStatusReqVO {
public class AiImageUpdateReqVO {
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "15583")
@NotNull(message = "编号不能为空")
private Long id;
@Schema(description = "是否发布", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
@NotNull(message = "是否发布不能为空")
@Schema(description = "是否发布", example = "true")
private Boolean publicStatus;
}

View File

@ -31,8 +31,7 @@ public class AiMidjourneyImagineReqVO {
@NotEmpty(message = "版本号不能为空")
private String version;
// TODO @fan参考图建议用 referImageUrl
@Schema(description = "垫图(参考图)base64数组")
private List<String> base64Array;
@Schema(description = "参考图")
private String referImageUrl;
}

View File

@ -1,13 +0,0 @@
### 生成音乐Suno +
POST {{baseUrl}}/ai/music/generate
Content-Type: application/json
Authorization: {{token}}
{
"platform": "Suno",
"generateMode": 1,
"prompt": "来一首快乐的歌曲",
"modelVersion": "chirp-v3.5",
"tags": ["Happy"],
"title": "Happy Song"
}

View File

@ -0,0 +1,26 @@
### 生成音乐Suno + 歌词模式
POST {{baseUrl}}/ai/music/generate
Content-Type: application/json
Authorization: {{token}}
{
"platform": "Suno",
"generateMode": 2,
"prompt": "周末啦!",
"model": "chirp-v3.5",
"tags": ["Happy"],
"title": "Happy Song"
}
### 生成音乐Suno + 描述模式
POST {{baseUrl}}/ai/music/generate
Content-Type: application/json
Authorization: {{token}}
{
"platform": "Suno",
"generateMode": 1,
"model": "chirp-v3.5",
"gptDescriptionPrompt": "今天是星球六,结果是个下雨天,希望心情很美丽",
"makeInstrumental": false
}

View File

@ -1,16 +1,19 @@
package cn.iocoder.yudao.module.ai.controller.admin.music;
import cn.hutool.core.util.ObjUtil;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiSunoGenerateReqVO;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.*;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import cn.iocoder.yudao.module.ai.service.music.AiMusicService;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.Parameter;
import io.swagger.v3.oas.annotations.tags.Tag;
import jakarta.annotation.Resource;
import jakarta.validation.Valid;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.security.access.prepost.PreAuthorize;
import org.springframework.web.bind.annotation.*;
import java.util.List;
@ -25,10 +28,71 @@ public class AiMusicController {
@Resource
private AiMusicService musicService;
@GetMapping("/my-page")
@Operation(summary = "获得【我的】音乐分页")
public CommonResult<PageResult<AiMusicRespVO>> getMusicMyPage(@Valid AiMusicPageReqVO pageReqVO) {
PageResult<AiMusicDO> pageResult = musicService.getMusicMyPage(pageReqVO, getLoginUserId());
return success(BeanUtils.toBean(pageResult, AiMusicRespVO.class));
}
@PostMapping("/generate")
@Operation(summary = "音乐生成")
public CommonResult<List<Long>> generateMusic(@RequestBody @Valid AiSunoGenerateReqVO reqVO) {
return success(musicService.generateMusic(getLoginUserId(), reqVO));
}
@Operation(summary = "删除【我的】音乐记录")
@DeleteMapping("/delete-my")
@Parameter(name = "id", required = true, description = "音乐编号", example = "1024")
public CommonResult<Boolean> deleteMusicMy(@RequestParam("id") Long id) {
musicService.deleteMusicMy(id, getLoginUserId());
return success(true);
}
@GetMapping("/get-my")
@Operation(summary = "获取【我的】音乐")
@Parameter(name = "id", required = true, description = "音乐编号", example = "1024")
public CommonResult<AiMusicRespVO> getMusicMy(@RequestParam("id") Long id) {
AiMusicDO music = musicService.getMusic(id);
if (music == null || ObjUtil.notEqual(getLoginUserId(), music.getUserId())) {
return success(null);
}
return success(BeanUtils.toBean(music, AiMusicRespVO.class));
}
@PostMapping("/update-my")
@Operation(summary = "修改【我的】音乐 目前只支持修改标题")
@Parameter(name = "title", required = true, description = "音乐名称", example = "夜空中最亮的星")
public CommonResult<Boolean> updateMy(AiMusicUpdateReqVO updateReqVO) {
musicService.updateMyMusic(updateReqVO, getLoginUserId());
return success(true);
}
// ================ 音乐管理 ================
@GetMapping("/page")
@Operation(summary = "获得音乐分页")
@PreAuthorize("@ss.hasPermission('ai:music:query')")
public CommonResult<PageResult<AiMusicRespVO>> getMusicPage(@Valid AiMusicPageReqVO pageReqVO) {
PageResult<AiMusicDO> pageResult = musicService.getMusicPage(pageReqVO);
return success(BeanUtils.toBean(pageResult, AiMusicRespVO.class));
}
@DeleteMapping("/delete")
@Operation(summary = "删除音乐")
@Parameter(name = "id", description = "编号", required = true)
@PreAuthorize("@ss.hasPermission('ai:music:delete')")
public CommonResult<Boolean> deleteMusic(@RequestParam("id") Long id) {
musicService.deleteMusic(id);
return success(true);
}
@PutMapping("/update")
@Operation(summary = "更新音乐")
@PreAuthorize("@ss.hasPermission('ai:music:update')")
public CommonResult<Boolean> updateMusic(@Valid @RequestBody AiMusicUpdateReqVO updateReqVO) {
musicService.updateMusic(updateReqVO);
return success(true);
}
}

View File

@ -0,0 +1,44 @@
package cn.iocoder.yudao.module.ai.controller.admin.music.vo;
import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.validation.InEnum;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.ToString;
import org.springframework.format.annotation.DateTimeFormat;
import java.time.LocalDateTime;
import static cn.iocoder.yudao.framework.common.util.date.DateUtils.FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND;
@Schema(description = "管理后台 - AI 音乐分页 Request VO")
@Data
@EqualsAndHashCode(callSuper = true)
@ToString(callSuper = true)
public class AiMusicPageReqVO extends PageParam {
@Schema(description = "用户编号", example = "12212")
private Long userId;
@Schema(description = "音乐名称", example = "夜空中最亮的星")
private String title;
@Schema(description = "音乐状态", example = "20")
@InEnum(AiMusicStatusEnum.class)
private Integer status;
@Schema(description = "生成模式", example = "1")
@InEnum(AiMusicGenerateModeEnum.class)
private Integer generateMode;
@Schema(description = "是否发布", example = "true")
private Boolean publicStatus;
@Schema(description = "创建时间")
@DateTimeFormat(pattern = FORMAT_YEAR_MONTH_DAY_HOUR_MINUTE_SECOND)
private LocalDateTime[] createTime;
}

View File

@ -0,0 +1,70 @@
package cn.iocoder.yudao.module.ai.controller.admin.music.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import lombok.Data;
import java.time.LocalDateTime;
import java.util.List;
@Schema(description = "管理后台 - AI 音乐 Response VO")
@Data
public class AiMusicRespVO {
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "24790")
private Long id;
@Schema(description = "用户编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "12212")
private Long userId;
@Schema(description = "音乐名称", requiredMode = Schema.RequiredMode.REQUIRED, example = "夜空中最亮的星")
private String title;
@Schema(description = "歌词", example = "oh~卖糕的")
private String lyric;
@Schema(description = "图片地址", example = "https://www.iocoder.cn")
private String imageUrl;
@Schema(description = "音频地址", example = "https://www.iocoder.cn")
private String audioUrl;
@Schema(description = "视频地址", example = "https://www.iocoder.cn")
private String videoUrl;
@Schema(description = "音乐状态", requiredMode = Schema.RequiredMode.REQUIRED, example = "20")
private Integer status;
@Schema(description = "描述词", example = "一首轻快的歌曲")
private String gptDescriptionPrompt;
@Schema(description = "提示词", example = "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。")
private String prompt;
@Schema(description = "模型平台", requiredMode = Schema.RequiredMode.REQUIRED, example = "Suno")
private String platform;
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "chirp-v3.5")
private String model;
@Schema(description = "生成模式", requiredMode = Schema.RequiredMode.REQUIRED, example = "1")
private Integer generateMode;
@Schema(description = "音乐风格标签")
private List<String> tags;
@Schema(description = "音乐时长", example = "[\"pop\",\"jazz\",\"punk\"]")
private Double duration;
@Schema(description = "是否发布", requiredMode = Schema.RequiredMode.REQUIRED, example = "true")
private Boolean publicStatus;
@Schema(description = "任务编号", example = "11369")
private String taskId;
@Schema(description = "错误信息")
private String errorMessage;
@Schema(description = "创建时间", requiredMode = Schema.RequiredMode.REQUIRED)
private LocalDateTime createTime;
}

View File

@ -0,0 +1,22 @@
package cn.iocoder.yudao.module.ai.controller.admin.music.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
@Schema(description = "管理后台 - AI 音乐修改 Request VO")
@Data
public class AiMusicUpdateReqVO {
@Schema(description = "编号", requiredMode = Schema.RequiredMode.REQUIRED, example = "15583")
@NotNull(message = "编号不能为空")
private Long id;
@Schema(description = "是否发布", example = "true")
private Boolean publicStatus;
// TODO @xin得单独一个 vo因为万一模拟请求就可以改 publicStatus
@Schema(description = "音乐名称", example = "夜空中最亮的星")
private String title;
}

View File

@ -2,6 +2,7 @@ package cn.iocoder.yudao.module.ai.controller.admin.music.vo;
import io.swagger.v3.oas.annotations.media.Schema;
import jakarta.validation.constraints.NotBlank;
import jakarta.validation.constraints.NotEmpty;
import jakarta.validation.constraints.NotNull;
import lombok.Data;
@ -15,24 +16,42 @@ public class AiSunoGenerateReqVO {
@NotBlank(message = "平台不能为空")
private String platform; // 参见 AiPlatformEnum 枚举
/**
* 1. 描述模式描述词 + 是否纯音乐 + 模型
* 2. 歌词模式歌词 + 音乐风格 + 标题 + 模型
*/
@Schema(description = "生成模式", requiredMode = Schema.RequiredMode.REQUIRED, example = "2")
@NotNull(message = "生成模式不能为空")
private Integer generateMode; // 参见 AiMusicGenerateModeEnum 枚举
@Schema(description = "用于生成音乐音频的提示", requiredMode = Schema.RequiredMode.REQUIRED,
example = "创作一首带有轻松吉他旋律的流行歌曲,[verse] 描述夏日海滩的宁静,[chorus] 节奏加快,表达对自由的向往。")
@Schema(description = "用于生成音乐音频的歌词提示",
example = """
1.描述模式创作一首带有轻松吉他旋律的流行歌曲[verse] 描述夏日海滩的宁静[chorus] 节奏加快表达对自由的向往
2.歌词模式
[Verse]
阳光下奔跑 多么欢快
假期就要来 心都飞起来
朋友在一旁 笑声又灿烂
无忧无虑的 每一天甜蜜
[Chorus]
马上放假了 快来庆祝
一起去旅行 快去冒险
日子太短暂 别再等待
马上放假了 梦想起飞
""")
private String prompt;
@Schema(description = "是否纯音乐", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "true")
@Schema(description = "是否纯音乐", example = "true")
private Boolean makeInstrumental;
@Schema(description = "模型版本", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "chirp-v3.5")
private String modelVersion; // 参见 AiModelEnum 枚举
@Schema(description = "模型", requiredMode = Schema.RequiredMode.REQUIRED, example = "chirp-v3.5")
@NotEmpty(message = "模型不能为空")
private String model; // 参见 AiModelEnum 枚举
@Schema(description = "音乐风格", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "[\"pop\",\"jazz\",\"punk\"]")
@Schema(description = "音乐风格", example = "[\"pop\",\"jazz\",\"punk\"]")
private List<String> tags;
@Schema(description = "音乐/歌曲名称", requiredMode = Schema.RequiredMode.NOT_REQUIRED, example = "夜空中最亮的星")
@Schema(description = "音乐/歌曲名称", example = "夜空中最亮的星")
private String title;
}

View File

@ -18,7 +18,7 @@ import java.util.List;
*
* @author xiaoxin
*/
@TableName("ai_music")
@TableName(value = "ai_music", autoResultMap = true)
@Data
public class AiMusicDO extends BaseDO {
@ -30,7 +30,7 @@ public class AiMusicDO extends BaseDO {
/**
* 用户编号
*
* <p>
* 关联 AdminUserDO userId 字段
*/
private Long userId;
@ -67,7 +67,7 @@ public class AiMusicDO extends BaseDO {
/**
* 生成模式
*
* <p>
* 枚举 {@link AiMusicGenerateModeEnum}
*/
private Integer generateMode;
@ -75,11 +75,7 @@ public class AiMusicDO extends BaseDO {
/**
* 描述词
*/
private String gptDescriptionPrompt;
/**
* 提示词
*/
private String prompt;
private String description;
/**
* 平台
@ -98,6 +94,16 @@ public class AiMusicDO extends BaseDO {
@TableField(typeHandler = JacksonTypeHandler.class)
private List<String> tags;
/**
* 音乐时长
*/
private Double duration;
/**
* 是否公开
*/
private Boolean publicStatus;
/**
* 任务编号
*/

View File

@ -1,6 +1,9 @@
package cn.iocoder.yudao.module.ai.dal.mysql.music;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.mybatis.core.mapper.BaseMapperX;
import cn.iocoder.yudao.framework.mybatis.core.query.LambdaQueryWrapperX;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiMusicPageReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import org.apache.ibatis.annotations.Mapper;
@ -18,4 +21,24 @@ public interface AiMusicMapper extends BaseMapperX<AiMusicDO> {
return selectList(AiMusicDO::getStatus, status);
}
default PageResult<AiMusicDO> selectPage(AiMusicPageReqVO reqVO) {
return selectPage(reqVO, new LambdaQueryWrapperX<AiMusicDO>()
.eqIfPresent(AiMusicDO::getUserId, reqVO.getUserId())
.eqIfPresent(AiMusicDO::getTitle, reqVO.getTitle())
.eqIfPresent(AiMusicDO::getStatus, reqVO.getStatus())
.eqIfPresent(AiMusicDO::getGenerateMode, reqVO.getGenerateMode())
.betweenIfPresent(AiMusicDO::getCreateTime, reqVO.getCreateTime())
.eqIfPresent(AiMusicDO::getPublicStatus, reqVO.getPublicStatus())
.orderByDesc(AiMusicDO::getId));
}
default PageResult<AiMusicDO> selectPageByMy(AiMusicPageReqVO reqVO, Long userId) {
return selectPage(reqVO, new LambdaQueryWrapperX<AiMusicDO>()
// 情况一公开
.eq(Boolean.TRUE.equals(reqVO.getPublicStatus()), AiMusicDO::getPublicStatus, reqVO.getPublicStatus())
// 情况二私有
.eq(Boolean.FALSE.equals(reqVO.getPublicStatus()), AiMusicDO::getUserId, userId)
.orderByAsc(AiMusicDO::getId));
}
}

View File

@ -4,10 +4,8 @@ import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.StrUtil;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.common.pojo.CommonResult;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@ -24,15 +22,17 @@ import cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.ai.service.model.AiChatModelService;
import cn.iocoder.yudao.module.ai.service.model.AiChatRoleService;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatOptions;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.*;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.qianfan.QianFanChatOptions;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import reactor.core.publisher.Flux;
@ -44,8 +44,8 @@ import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionU
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.error;
import static cn.iocoder.yudao.framework.common.pojo.CommonResult.success;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.AI_CHAT_MESSAGE_NOT_EXIST;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_CONVERSATION_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.CHAT_MESSAGE_NOT_EXIST;
/**
* AI 聊天消息 Service 实现类
@ -117,7 +117,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
List<AiChatMessageDO> historyMessages = chatMessageMapper.selectListByConversationId(conversation.getId());
// 1.2 校验模型
AiChatModelDO model = chatModalService.validateChatModel(conversation.getModelId());
StreamingChatClient chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
StreamingChatModel chatClient = apiKeyService.getStreamingChatClient(model.getKeyId());
// 1.3 获取用户头像角色头像
AiChatRoleDO role = conversation.getRoleId() != null ? chatRoleService.getChatRole(conversation.getRoleId()) : null;
@ -150,7 +150,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
log.error("[sendChatMessageStream][userId({}) sendReqVO({}) 发生异常]", userId, sendReqVO, throwable);
chatMessageMapper.updateById(new AiChatMessageDO().setId(assistantMessage.getId()).setContent(throwable.getMessage()));
}).onErrorResume(error -> {
return Flux.just(error(ErrorCodeConstants.AI_CHAT_STREAM_ERROR));
return Flux.just(error(ErrorCodeConstants.CHAT_STREAM_ERROR));
});
}
@ -164,7 +164,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
}
// 1.2 history message 历史消息
List<AiChatMessageDO> contextMessages = filterContextMessages(messages, conversation, sendReqVO);
contextMessages.forEach(message -> chatMessages.add(new ChatMessage(message.getType().toUpperCase(), message.getContent())));
contextMessages.forEach(message -> {
// TODO @芋艿看看有没优化空间
if (MessageType.USER.getValue().equals(message.getType())) {
chatMessages.add(new UserMessage(message.getContent()));
} else {
chatMessages.add(new AssistantMessage(message.getContent()));
}
});
// 1.3 user message 新发送消息
chatMessages.add(new UserMessage(sendReqVO.getContent()));
@ -184,14 +191,14 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
case OLLAMA:
return OllamaOptions.create().withModel(model).withTemperature(temperatureF).withNumPredict(maxTokens);
case YI_YAN:
// TODO @fan增加一个 model
return new YiYanChatOptions().setTemperature(temperatureF).setMaxOutputTokens(maxTokens);
// TODO 芋艿貌似 model 只要一设置就报错
// return QianFanChatOptions.builder().withModel(model).withTemperature(temperatureF).withMaxTokens(maxTokens).build();
return QianFanChatOptions.builder().withTemperature(temperatureF).withMaxTokens(maxTokens).build();
case XING_HUO:
return new XingHuoOptions().setChatModel(XingHuoChatModel.valueOfModel(model)).setTemperature(temperatureF)
.setMaxTokens(maxTokens);
case QIAN_WEN:
// TODO @fan:增加 modeltemperature 参数
return new QianWenOptions().setMaxTokens(maxTokens);
return TongYiChatOptions.builder().withModel(model).withTemperature(temperature).withMaxTokens(maxTokens).build();
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
@ -257,7 +264,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 1. 校验消息存在
AiChatMessageDO message = chatMessageMapper.selectById(id);
if (message == null || ObjUtil.notEqual(message.getUserId(), userId)) {
throw exception(AI_CHAT_MESSAGE_NOT_EXIST);
throw exception(CHAT_MESSAGE_NOT_EXIST);
}
// 2. 执行删除
chatMessageMapper.deleteById(id);
@ -268,7 +275,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 1. 校验消息存在
List<AiChatMessageDO> messages = chatMessageMapper.selectListByConversationId(conversationId);
if (CollUtil.isEmpty(messages) || ObjUtil.notEqual(messages.get(0).getUserId(), userId)) {
throw exception(AI_CHAT_MESSAGE_NOT_EXIST);
throw exception(CHAT_MESSAGE_NOT_EXIST);
}
// 2. 执行删除
chatMessageMapper.deleteBatchIds(convertList(messages, AiChatMessageDO::getId));
@ -279,7 +286,7 @@ public class AiChatMessageServiceImpl implements AiChatMessageService {
// 1. 校验消息存在
AiChatMessageDO message = chatMessageMapper.selectById(id);
if (message == null) {
throw exception(AI_CHAT_MESSAGE_NOT_EXIST);
throw exception(CHAT_MESSAGE_NOT_EXIST);
}
// 2. 执行删除
chatMessageMapper.deleteById(id);

View File

@ -5,12 +5,14 @@ import cn.iocoder.yudao.framework.common.pojo.PageParam;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdatePublicStatusReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
import jakarta.validation.Valid;
import java.util.List;
/**
* AI 绘图 Service 接口
*
@ -35,6 +37,14 @@ public interface AiImageService {
*/
AiImageDO getImage(Long id);
/**
* 获得绘图列表
*
* @param ids 绘图编号数组
* @return 绘图记录列表
*/
List<AiImageDO> getImageList(List<Long> ids);
/**
* 绘制图片
*
@ -61,11 +71,11 @@ public interface AiImageService {
PageResult<AiImageDO> getImagePage(AiImagePageReqVO pageReqVO);
/**
* 更新绘画发布状态
* 更新绘画
*
* @param updateReqVO 更新信息
*/
void updateImagePublicStatus(@Valid AiImageUpdatePublicStatusReqVO updateReqVO);
void updateImage(@Valid AiImageUpdateReqVO updateReqVO);
/**
* 删除绘画

View File

@ -15,7 +15,7 @@ import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageDrawReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImagePageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdatePublicStatusReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.AiImageUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyActionReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.image.vo.midjourney.AiMidjourneyImagineReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.image.AiImageDO;
@ -25,7 +25,7 @@ import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.image.ImageOptions;
import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
@ -35,6 +35,8 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
@ -61,9 +63,6 @@ public class AiImageServiceImpl implements AiImageService {
@Resource
private AiApiKeyService apiKeyService;
@Resource
private MidjourneyApi midjourneyApi;
@Override
public PageResult<AiImageDO> getImagePageMy(Long userId, PageParam pageReqVO) {
return imageMapper.selectPage(userId, pageReqVO);
@ -74,6 +73,14 @@ public class AiImageServiceImpl implements AiImageService {
return imageMapper.selectById(id);
}
@Override
public List<AiImageDO> getImageList(List<Long> ids) {
if (CollUtil.isEmpty(ids)) {
return Collections.emptyList();
}
return imageMapper.selectBatchIds(ids);
}
@Override
public Long drawImage(Long userId, AiImageDrawReqVO drawReqVO) {
// 1. 保存数据库
@ -91,7 +98,7 @@ public class AiImageServiceImpl implements AiImageService {
// 1.1 构建请求
ImageOptions request = buildImageOptions(req);
// 1.2 执行请求
ImageClient imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
ImageModel imageClient = apiKeyService.getImageClient(AiPlatformEnum.validatePlatform(req.getPlatform()));
ImageResponse response = imageClient.call(new ImagePrompt(req.getPrompt(), request));
// 2. 上传到文件服务
@ -117,9 +124,16 @@ public class AiImageServiceImpl implements AiImageService {
.withResponseFormat("b64_json")
.build();
} else if (ObjUtil.equal(draw.getPlatform(), AiPlatformEnum.STABLE_DIFFUSION.getPlatform())) {
// https://platform.stability.ai/docs/api-reference#tag/SDXL-and-SD1.6/operation/textToImage
// https://platform.stability.ai/docs/api-reference#tag/Text-to-Image/operation/textToImage
return StabilityAiImageOptions.builder().withModel(draw.getModel())
.withHeight(draw.getHeight()).withWidth(draw.getWidth()) // TODO @芋艿各种参数
.withHeight(draw.getHeight()).withWidth(draw.getWidth())
.withSeed(Long.valueOf(draw.getOptions().get("seed")))
.withCfgScale(Float.valueOf(draw.getOptions().get("scale")))
.withSteps(Integer.valueOf(draw.getOptions().get("steps")))
.withSampler(String.valueOf(draw.getOptions().get("sampler")))
.withStylePreset(String.valueOf(draw.getOptions().get("stylePreset")))
.withClipGuidancePreset(String.valueOf(draw.getOptions().get("clipGuidancePreset")))
.build();
}
throw new IllegalArgumentException("不支持的 AI 平台:" + draw.getPlatform());
@ -130,7 +144,7 @@ public class AiImageServiceImpl implements AiImageService {
// 1. 校验是否存在
AiImageDO image = validateImageExists(id);
if (ObjUtil.notEqual(image.getUserId(), userId)) {
throw exception(AI_IMAGE_NOT_EXISTS);
throw exception(IMAGE_NOT_EXISTS);
}
// 2. 删除记录
imageMapper.deleteById(id);
@ -142,7 +156,7 @@ public class AiImageServiceImpl implements AiImageService {
}
@Override
public void updateImagePublicStatus(AiImageUpdatePublicStatusReqVO updateReqVO) {
public void updateImage(AiImageUpdateReqVO updateReqVO) {
// 1. 校验存在
validateImageExists(updateReqVO.getId());
// 2. 更新发布状态
@ -160,7 +174,7 @@ public class AiImageServiceImpl implements AiImageService {
private AiImageDO validateImageExists(Long id) {
AiImageDO image = imageMapper.selectById(id);
if (image == null) {
throw exception(AI_IMAGE_NOT_EXISTS);
throw exception(IMAGE_NOT_EXISTS);
}
return image;
}
@ -170,6 +184,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override
@Transactional(rollbackFor = Exception.class)
public Long midjourneyImagine(Long userId, AiMidjourneyImagineReqVO reqVO) {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
// 1. 保存数据库
AiImageDO image = BeanUtils.toBean(reqVO, AiImageDO.class).setUserId(userId).setPublicStatus(false)
.setStatus(AiImageStatusEnum.IN_PROGRESS.getStatus())
@ -177,16 +192,21 @@ public class AiImageServiceImpl implements AiImageService {
imageMapper.insert(image);
// 2. 调用 Midjourney Proxy 提交任务
List<String> base64Array = new ArrayList<>(8);
if (StrUtil.isNotBlank(reqVO.getReferImageUrl())) {
base64Array.add("data:image/jpeg;base64,".concat(Base64.encode(HttpUtil.downloadBytes(reqVO.getReferImageUrl()))));
}
MidjourneyApi.ImagineRequest imagineRequest = new MidjourneyApi.ImagineRequest(
null, reqVO.getPrompt(),null,
MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(), reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
base64Array, reqVO.getPrompt(),null,
MidjourneyApi.ImagineRequest.buildState(reqVO.getWidth(),
reqVO.getHeight(), reqVO.getVersion(), reqVO.getModel()));
MidjourneyApi.SubmitResponse imagineResponse = midjourneyApi.imagine(imagineRequest);
// 3. 情况一失败抛出业务异常
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(imagineResponse.code())) {
String description = imagineResponse.description().contains("quota_not_enough") ?
"账户余额不足" : imagineResponse.description();
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
throw exception(IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
}
// 4. 情况二成功更新 taskId 和参数
@ -197,6 +217,7 @@ public class AiImageServiceImpl implements AiImageService {
@Override
public Integer midjourneySync() {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
// 1.1 获取 Midjourney 平台状态在 进行中 image
List<AiImageDO> imageList = imageMapper.selectListByStatusAndPlatform(
AiImageStatusEnum.IN_PROGRESS.getStatus(), AiPlatformEnum.MIDJOURNEY.getPlatform());
@ -263,16 +284,17 @@ public class AiImageServiceImpl implements AiImageService {
@Override
public Long midjourneyAction(Long userId, AiMidjourneyActionReqVO reqVO) {
MidjourneyApi midjourneyApi = apiKeyService.getMidjourneyApi();
// 1.1 检查 image
AiImageDO image = validateImageExists(reqVO.getId());
if (ObjUtil.notEqual(userId, image.getUserId())) {
throw exception(AI_IMAGE_NOT_EXISTS);
throw exception(IMAGE_NOT_EXISTS);
}
// 1.2 检查 customId
MidjourneyApi.Button button = CollUtil.findOne(image.getButtons(),
buttonX -> buttonX.customId().equals(reqVO.getCustomId()));
if (button == null) {
throw exception(AI_IMAGE_CUSTOM_ID_NOT_EXISTS);
throw exception(IMAGE_CUSTOM_ID_NOT_EXISTS);
}
// 2. 调用 Midjourney Proxy 提交任务
@ -281,7 +303,7 @@ public class AiImageServiceImpl implements AiImageService {
if (!MidjourneyApi.SubmitCodeEnum.SUCCESS_CODES.contains(actionResponse.code())) {
String description = actionResponse.description().contains("quota_not_enough") ?
"账户余额不足" : actionResponse.description();
throw exception(AI_IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
throw exception(IMAGE_MIDJOURNEY_SUBMIT_FAIL, description);
}
// 3. 新增 image 记录

View File

@ -1,58 +0,0 @@
package cn.iocoder.yudao.module.ai.service.image;
import lombok.Data;
import org.springframework.ai.image.ImageOptions;
/**
* @author fansili
* @time 2024/6/5 10:34
* @since 1.0
*/
@Data
public class MidjourneyImageOptions implements ImageOptions {
/**
* 模型
*/
private String model;
/**
* 宽度
*/
private Integer width;
/**
* 高度
*/
private Integer height;
/**
* 版本
*/
private String version;
/**
* 参数
*/
private String state;
@Override
public Integer getN() {
return 0;
}
@Override
public String getModel() {
return model;
}
@Override
public Integer getWidth() {
return width;
}
@Override
public Integer getHeight() {
return height;
}
@Override
public String getResponseFormat() {
return "";
}
}

View File

@ -1,13 +1,15 @@
package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeyPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import jakarta.validation.Valid;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageModel;
import java.util.List;
@ -79,7 +81,7 @@ public interface AiApiKeyService {
* @param id 编号
* @return StreamingChatClient 对象
*/
StreamingChatClient getStreamingChatClient(Long id);
StreamingChatModel getStreamingChatClient(Long id);
/**
* 获得 ImageClient 对象
@ -89,6 +91,24 @@ public interface AiApiKeyService {
* @param platform 平台
* @return ImageClient 对象
*/
ImageClient getImageClient(AiPlatformEnum platform);
ImageModel getImageClient(AiPlatformEnum platform);
/**
* 获得 MidjourneyApi 对象
*
* TODO 可优化点目前默认获取 Midjourney 对应的第一个开启的配置用于绘画后续可以支持配置选择
*
* @return MidjourneyApi 对象
*/
MidjourneyApi getMidjourneyApi();
/**
* 获得 SunoApi 对象
*
* TODO 可优化点目前默认获取 Suno 对应的第一个开启的配置用于音乐后续可以支持配置选择
*
* @return SunoApi 对象
*/
SunoApi getSunoApi();
}

View File

@ -2,6 +2,8 @@ package cn.iocoder.yudao.module.ai.service.model;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.enums.CommonStatusEnum;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.framework.common.util.object.BeanUtils;
@ -10,8 +12,8 @@ import cn.iocoder.yudao.module.ai.controller.admin.model.vo.apikey.AiApiKeySaveR
import cn.iocoder.yudao.module.ai.dal.dataobject.model.AiApiKeyDO;
import cn.iocoder.yudao.module.ai.dal.mysql.model.AiApiKeyMapper;
import jakarta.annotation.Resource;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.stereotype.Service;
import org.springframework.validation.annotation.Validated;
@ -96,14 +98,14 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
// ========== spring-ai 集成 ==========
@Override
public StreamingChatClient getStreamingChatClient(Long id) {
public StreamingChatModel getStreamingChatClient(Long id) {
AiApiKeyDO apiKey = validateApiKey(id);
AiPlatformEnum platform = AiPlatformEnum.validatePlatform(apiKey.getPlatform());
return clientFactory.getOrCreateStreamingChatClient(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public ImageClient getImageClient(AiPlatformEnum platform) {
public ImageModel getImageClient(AiPlatformEnum platform) {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(platform.getName(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) {
return null;
@ -111,4 +113,25 @@ public class AiApiKeyServiceImpl implements AiApiKeyService {
return clientFactory.getOrCreateImageClient(platform, apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public MidjourneyApi getMidjourneyApi() {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
AiPlatformEnum.MIDJOURNEY.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
// todo @芋艿 这些地方直接抛异常会好点不然调用到的地方都需要做判断
if (apiKey == null) {
return null;
}
return clientFactory.getOrCreateMidjourneyApi(apiKey.getApiKey(), apiKey.getUrl());
}
@Override
public SunoApi getSunoApi() {
AiApiKeyDO apiKey = apiKeyMapper.selectFirstByPlatformAndStatus(
AiPlatformEnum.SUNO.getPlatform(), CommonStatusEnum.ENABLE.getStatus());
if (apiKey == null) {
return null;
}
return clientFactory.getOrCreateSunoApi(apiKey.getApiKey(), apiKey.getUrl());
}
}

View File

@ -1,6 +1,9 @@
package cn.iocoder.yudao.module.ai.service.music;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiSunoGenerateReqVO;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.*;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import jakarta.validation.Valid;
import java.util.List;
@ -15,7 +18,7 @@ public interface AiMusicService {
* 音乐生成
*
* @param userId 用户编号
* @param reqVO 请求参数
* @param reqVO 请求参数
* @return 生成的音乐ID
*/
List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO);
@ -27,4 +30,58 @@ public interface AiMusicService {
*/
Integer syncMusic();
/**
* 更新音乐发布状态
*
* @param updateReqVO 更新信息
*/
void updateMusic(@Valid AiMusicUpdateReqVO updateReqVO);
/**
* 更新我的音乐
*
* @param updateReqVO 更新信息
*/
void updateMyMusic(@Valid AiMusicUpdateReqVO updateReqVO, Long userId);
/**
* 删除AI 音乐
*
* @param id 编号
*/
void deleteMusic(Long id);
/**
* 删除我的音乐记录
*
* @param id 音乐编号
* @param userId 用户编号
*/
void deleteMusicMy(Long id, Long userId);
/**
* 获得AI 音乐
*
* @param id 音乐编号
* @return 音乐内容
*/
AiMusicDO getMusic(Long id);
/**
* 获得音乐分页
*
* @param pageReqVO 分页查询
* @return 音乐分页
*/
PageResult<AiMusicDO> getMusicPage(AiMusicPageReqVO pageReqVO);
/**
* 获得我的音乐分页
*
* @param pageReqVO 分页查询
* @param userId 用户编号
* @return 音乐分页
*/
PageResult<AiMusicDO> getMusicMyPage(AiMusicPageReqVO pageReqVO, Long userId);
}

View File

@ -2,21 +2,32 @@ package cn.iocoder.yudao.module.ai.service.music;
import cn.hutool.core.collection.CollUtil;
import cn.hutool.core.text.StrPool;
import cn.hutool.core.util.ObjUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import cn.hutool.http.HttpUtil;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.common.pojo.PageResult;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiMusicPageReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiMusicUpdateReqVO;
import cn.iocoder.yudao.module.ai.controller.admin.music.vo.AiSunoGenerateReqVO;
import cn.iocoder.yudao.module.ai.dal.dataobject.music.AiMusicDO;
import cn.iocoder.yudao.module.ai.dal.mysql.music.AiMusicMapper;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicGenerateModeEnum;
import cn.iocoder.yudao.module.ai.enums.music.AiMusicStatusEnum;
import cn.iocoder.yudao.module.ai.service.model.AiApiKeyService;
import cn.iocoder.yudao.module.infra.api.file.FileApi;
import jakarta.annotation.Resource;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
import java.util.*;
import static cn.iocoder.yudao.framework.common.exception.util.ServiceExceptionUtil.exception;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertList;
import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.convertMap;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.IMAGE_NOT_EXISTS;
import static cn.iocoder.yudao.module.ai.enums.ErrorCodeConstants.MUSIC_NOT_EXISTS;
/**
* AI 音乐 Service 实现类
@ -28,25 +39,30 @@ import static cn.iocoder.yudao.framework.common.util.collection.CollectionUtils.
public class AiMusicServiceImpl implements AiMusicService {
@Resource
private SunoApi sunoApi;
private AiApiKeyService apiKeyService;
@Resource
private AiMusicMapper musicMapper;
@Resource
private FileApi fileApi;
@Override
public List<Long> generateMusic(Long userId, AiSunoGenerateReqVO reqVO) {
// 1. 调用 Suno 生成音乐
SunoApi sunoApi = apiKeyService.getSunoApi();
// TODO 芋艿这两个貌似一直没跑成功你那可以么用的请求是 AiMusicController.http --xin大部分ok的补充了error_message
List<SunoApi.MusicData> musicDataList;
if (Objects.equals(AiMusicGenerateModeEnum.LYRIC.getMode(), reqVO.getGenerateMode())) {
// 1.1 歌词模式
if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
// 1.1 描述模式
SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
reqVO.getPrompt(), reqVO.getModelVersion(), CollUtil.join(reqVO.getTags(), StrPool.COMMA), reqVO.getTitle());
musicDataList = sunoApi.customGenerate(generateRequest);
} else if (Objects.equals(AiMusicGenerateModeEnum.DESCRIPTION.getMode(), reqVO.getGenerateMode())) {
// 1.2 描述模式
SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
reqVO.getPrompt(), reqVO.getModelVersion(), reqVO.getMakeInstrumental());
reqVO.getPrompt(), reqVO.getModel(), reqVO.getMakeInstrumental());
musicDataList = sunoApi.generate(generateRequest);
} else if (Objects.equals(AiMusicGenerateModeEnum.LYRIC.getMode(), reqVO.getGenerateMode())) {
// 1.2 歌词模式
SunoApi.MusicGenerateRequest generateRequest = new SunoApi.MusicGenerateRequest(
reqVO.getPrompt(), reqVO.getModel(), CollUtil.join(reqVO.getTags(), StrPool.COMMA), reqVO.getTitle());
musicDataList = sunoApi.customGenerate(generateRequest);
} else {
throw new IllegalArgumentException(StrUtil.format("未知生成模式({})", reqVO));
}
@ -56,7 +72,7 @@ public class AiMusicServiceImpl implements AiMusicService {
return Collections.emptyList();
}
List<AiMusicDO> musicList = buildMusicDOList(musicDataList);
musicList.forEach(music -> music.setUserId(userId).setPlatform(music.getPlatform()).setGenerateMode(reqVO.getGenerateMode()));
musicList.forEach(music -> music.setUserId(userId).setPlatform(reqVO.getPlatform()).setGenerateMode(reqVO.getGenerateMode()));
musicMapper.insertBatch(musicList);
return convertList(musicList, AiMusicDO::getId);
}
@ -70,6 +86,7 @@ public class AiMusicServiceImpl implements AiMusicService {
log.info("[syncMusic][Suno 开始同步, 共 ({}) 个任务]", streamingTask.size());
// GET 请求为避免参数过长分批次处理
SunoApi sunoApi = apiKeyService.getSunoApi();
CollUtil.split(streamingTask, 36).forEach(chunkList -> {
Map<String, Long> taskIdMap = convertMap(chunkList, AiMusicDO::getTaskId, AiMusicDO::getId);
List<SunoApi.MusicData> musicTaskList = sunoApi.getMusicList(new ArrayList<>(taskIdMap.keySet()));
@ -85,19 +102,120 @@ public class AiMusicServiceImpl implements AiMusicService {
return streamingTask.size();
}
@Override
public void updateMusic(AiMusicUpdateReqVO updateReqVO) {
// 校验存在
validateMusicExists(updateReqVO.getId());
// 更新
musicMapper.updateById(new AiMusicDO().setId(updateReqVO.getId()).setPublicStatus(updateReqVO.getPublicStatus()));
}
@Override
public void updateMyMusic(AiMusicUpdateReqVO updateReqVO, Long userId) {
// 校验音乐是否存在
AiMusicDO musicDO = validateMusicExists(updateReqVO.getId());
if (ObjUtil.notEqual(musicDO.getUserId(), userId)) {
throw exception(MUSIC_NOT_EXISTS);
}
// 更新
musicMapper.updateById(new AiMusicDO().setId(updateReqVO.getId()).setTitle(updateReqVO.getTitle()));
}
@Override
public void deleteMusic(Long id) {
// 校验存在
validateMusicExists(id);
// 删除
musicMapper.deleteById(id);
}
@Override
public void deleteMusicMy(Long id, Long userId) {
// 1. 校验是否存在
AiMusicDO music = validateMusicExists(id);
if (ObjUtil.notEqual(music.getUserId(), userId)) {
throw exception(IMAGE_NOT_EXISTS);
}
// 2. 删除记录
musicMapper.deleteById(id);
}
@Override
public AiMusicDO getMusic(Long id) {
return musicMapper.selectById(id);
}
@Override
public PageResult<AiMusicDO> getMusicPage(AiMusicPageReqVO pageReqVO) {
return musicMapper.selectPage(pageReqVO);
}
@Override
public PageResult<AiMusicDO> getMusicMyPage(AiMusicPageReqVO pageReqVO, Long userId) {
return musicMapper.selectPageByMy(pageReqVO, userId);
}
/**
* 构建 AiMusicDO 集合
*
* @param musicList suno 音乐任务列表
* @return AiMusicDO 集合
*/
private static List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicList) {
return convertList(musicList, musicData -> new AiMusicDO()
.setTaskId(musicData.id()).setModel(musicData.modelName())
.setPrompt(musicData.prompt()).setGptDescriptionPrompt(musicData.gptDescriptionPrompt())
.setAudioUrl(musicData.audioUrl()).setVideoUrl(musicData.videoUrl()).setImageUrl(musicData.imageUrl())
.setTitle(musicData.title()).setLyric(musicData.lyric()).setTags(StrUtil.split(musicData.tags(), StrPool.COMMA))
.setStatus(Objects.equals("complete", musicData.status()) ? AiMusicStatusEnum.SUCCESS.getStatus() : AiMusicStatusEnum.IN_PROGRESS.getStatus()));
private List<AiMusicDO> buildMusicDOList(List<SunoApi.MusicData> musicList) {
return convertList(musicList, musicData -> {
Integer status;
if (Objects.equals("complete", musicData.status())) {
status = AiMusicStatusEnum.SUCCESS.getStatus();
} else if (Objects.equals("error", musicData.status())) {
status = AiMusicStatusEnum.FAIL.getStatus();
} else {
status = AiMusicStatusEnum.IN_PROGRESS.getStatus();
}
return new AiMusicDO()
.setTaskId(musicData.id()).setModel(musicData.modelName())
.setDescription(musicData.gptDescriptionPrompt())
.setAudioUrl(downloadFile(status, musicData.audioUrl()))
.setVideoUrl(downloadFile(status, musicData.videoUrl()))
.setImageUrl(downloadFile(status, musicData.imageUrl()))
.setTitle(musicData.title()).setDuration(musicData.duration())
.setLyric(musicData.lyric()).setTags(StrUtil.split(musicData.tags(), StrPool.COMMA))
.setErrorMessage(musicData.errorMessage())
.setStatus(status);
});
}
/**
* 音乐生成好后将音频文件上传到文件服务器
*
* @param status 音乐状态
* @param url 音频文件地址
* @return 内部文件地址
*/
private String downloadFile(Integer status, String url) {
if (StrUtil.isBlank(url) || ObjectUtil.notEqual(status, AiMusicStatusEnum.SUCCESS.getStatus())) {
return url;
}
try {
byte[] bytes = HttpUtil.downloadBytes(url);
return fileApi.createFile(bytes);
} catch (Exception e) {
log.error("[downloadFile][url({}) 下载失败]", url, e);
return url;
}
}
/**
* 校验音乐是否存在
*
* @param id 音乐编号
* @return 音乐信息
*/
private AiMusicDO validateMusicExists(Long id) {
AiMusicDO music = musicMapper.selectById(id);
if (music == null) {
throw exception(MUSIC_NOT_EXISTS);
}
return music;
}
}

View File

@ -2,33 +2,38 @@
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>cn.iocoder.boot</groupId>
<artifactId>yudao-module-ai</artifactId>
<version>${revision}</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<artifactId>yudao-spring-boot-starter-ai</artifactId>
<name>${project.artifactId}</name>
<description>AI 大模型拓展,接入国内外大模型</description>
<properties>
<spring-ai.version>1.0.0-M1</spring-ai.version>
</properties>
<dependencies>
<dependency>
<groupId>io.springboot.ai</groupId>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-ollama-spring-boot-starter</artifactId>
<version>1.0.3</version>
<version>${spring-ai.version}</version>
</dependency>
<dependency>
<groupId>io.springboot.ai</groupId>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
<version>1.0.3</version>
<version>${spring-ai.version}</version>
</dependency>
<dependency>
<groupId>io.springboot.ai</groupId>
<artifactId>spring-ai-stability-ai</artifactId>
<version>1.0.3</version>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-stability-ai-spring-boot-starter</artifactId>
<version>${spring-ai.version}</version>
</dependency>
<!-- <dependency>-->
<!-- <groupId>io.springboot.ai</groupId>-->
<!-- <groupId>org.springframework.ai</groupId>-->
<!-- <artifactId>spring-ai-vertex-ai-gemini</artifactId>-->
<!-- <version>1.0.3</version>-->
<!-- </dependency>-->
@ -38,11 +43,19 @@
<artifactId>yudao-common</artifactId>
</dependency>
<!-- TODO 芋艿:等 spring-ai 官方发布后,需要把 groupId 改下 -->
<dependency>
<groupId>group.springframework.ai</groupId>
<artifactId>spring-ai-qianfan-spring-boot-starter</artifactId>
<version>1.1.0</version>
</dependency>
<!-- 阿里云 通义千问 -->
<!-- TODO 芋艿:等 spring cloud alibaba ai 发布最新的时候,可以替换掉这个依赖,并且删除我们直接 cv 的代码 -->
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>dashscope-sdk-java</artifactId>
<version>2.11.0</version>
<version>2.14.0</version>
</dependency>
<dependency>

View File

@ -4,21 +4,16 @@ import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactory;
import cn.iocoder.yudao.framework.ai.core.factory.AiClientFactoryImpl;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import lombok.extern.slf4j.Slf4j;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Import;
/**
* 芋道 AI 自动配置
@ -28,6 +23,7 @@ import org.springframework.context.annotation.Bean;
@AutoConfiguration
@EnableConfigurationProperties(YudaoAiProperties.class)
@Slf4j
@Import(TongYiAutoConfiguration.class)
public class YudaoAiAutoConfiguration {
@Bean
@ -57,48 +53,6 @@ public class YudaoAiAutoConfiguration {
);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.qianwen.enable", havingValue = "true")
public QianWenChatClient qianWenChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.QianWenProperties qianWenProperties = yudaoAiProperties.getQianwen();
// 转换配置
QianWenOptions qianWenOptions = new QianWenOptions();
// qianWenOptions.setModel(qianWenProperties.getModel().getModel()); TODO @fan这里报错了
qianWenOptions.setTemperature(qianWenProperties.getTemperature());
// qianWenOptions.setTopK(qianWenProperties.getTopK()); TODO 芋艿后续弄
qianWenOptions.setTopP(qianWenProperties.getTopP());
qianWenOptions.setMaxTokens(qianWenProperties.getMaxTokens());
// qianWenOptions.setTemperature(qianWenProperties.getTemperature()); TODO 芋艿后续弄
return new QianWenChatClient(
new QianWenApi(
qianWenProperties.getApiKey(),
QianWenChatModal.QWEN_72B_CHAT
),
qianWenOptions
);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.yiyan.enable", havingValue = "true")
public YiYanChatClient yiYanChatClient(YudaoAiProperties yudaoAiProperties) {
YudaoAiProperties.YiYanProperties yiYanProperties = yudaoAiProperties.getYiyan();
// 转换配置
YiYanChatOptions yiYanOptions = new YiYanChatOptions();
// yiYanOptions.setTopK(yiYanProperties.getTopK()); TODO 芋艿后续弄
yiYanOptions.setTopP(yiYanProperties.getTopP());
yiYanOptions.setTemperature(yiYanProperties.getTemperature());
yiYanOptions.setMaxOutputTokens(yiYanProperties.getMaxTokens());
return new YiYanChatClient(
new YiYanApi(
yiYanProperties.getAppKey(),
yiYanProperties.getSecretKey(),
yiYanProperties.getModel(),
yiYanProperties.getRefreshTokenSecondTime()
),
yiYanOptions
);
}
@Bean
@ConditionalOnProperty(value = "yudao.ai.midjourney.enable", havingValue = "true")
public MidjourneyApi midjourneyApi(YudaoAiProperties yudaoAiProperties) {

View File

@ -2,7 +2,6 @@ package cn.iocoder.yudao.framework.ai.config;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import lombok.Data;
import lombok.experimental.Accessors;
import org.springframework.ai.autoconfigure.openai.OpenAiImageProperties;
@ -20,9 +19,7 @@ import org.springframework.boot.context.properties.ConfigurationProperties;
@ConfigurationProperties(prefix = "yudao.ai")
public class YudaoAiProperties {
private QianWenProperties qianwen;
private XingHuoProperties xinghuo;
private YiYanProperties yiyan;
private OpenAiImageProperties openAiImage;
private MidjourneyProperties midjourney;
private SunoProperties suno;
@ -47,22 +44,6 @@ public class YudaoAiProperties {
}
@Data
@Accessors(chain = true)
public static class QianWenProperties extends ChatProperties {
/**
* api key
*/
private String apiKey;
/**
* model
*/
private YiYanChatModel model;
}
@Data
@Accessors(chain = true)
public static class XingHuoProperties extends ChatProperties {
private String appId;
@ -72,28 +53,6 @@ public class YudaoAiProperties {
}
@Data
@Accessors(chain = true)
public static class YiYanProperties extends ChatProperties {
/**
* appKey
*/
private String appKey;
/**
* secretKey
*/
private String secretKey;
/**
* 模型
*/
private YiYanChatModel model = YiYanChatModel.ERNIE4_3_5_8K;
/**
* token 刷新时间(默认 86400 = 24小时)
*/
private int refreshTokenSecondTime = 86400;
}
@Data
public static class MidjourneyProperties {

View File

@ -1,8 +1,10 @@
package cn.iocoder.yudao.framework.ai.core.factory;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.image.ImageClient;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageModel;
/**
* AI 客户端工厂的接口类
@ -21,7 +23,7 @@ public interface AiClientFactory {
* @param url API URL
* @return StreamingChatClient 对象
*/
StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url);
StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于默认配置获得 StreamingChatClient 对象
@ -31,7 +33,7 @@ public interface AiClientFactory {
* @param platform 平台
* @return StreamingChatClient 对象
*/
StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform);
StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform);
/**
* 基于默认配置获得 ImageClient 对象
@ -41,7 +43,7 @@ public interface AiClientFactory {
* @param platform 平台
* @return ImageClient 对象
*/
ImageClient getDefaultImageClient(AiPlatformEnum platform);
ImageModel getDefaultImageClient(AiPlatformEnum platform);
/**
* 基于指定配置获得 ImageClient 对象
@ -53,6 +55,28 @@ public interface AiClientFactory {
* @param url API URL
* @return ImageClient 对象
*/
ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url);
/**
* 基于指定配置获得 MidjourneyApi 对象
*
* 如果不存在则进行创建
*
* @param apiKey API KEY
* @param url API URL
* @return MidjourneyApi 对象
*/
MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url);
/**
* 基于指定配置获得 SunoApi 对象
*
* 如果不存在则进行创建
*
* @param apiKey API KEY
* @param url API URL
* @return SunoApi 对象
*/
SunoApi getOrCreateSunoApi(String apiKey, String url);
}

View File

@ -9,26 +9,35 @@ import cn.hutool.extra.spring.SpringUtil;
import cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration;
import cn.iocoder.yudao.framework.ai.config.YudaoAiProperties;
import cn.iocoder.yudao.framework.ai.core.enums.AiPlatformEnum;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import cn.iocoder.yudao.framework.ai.core.model.midjourney.api.MidjourneyApi;
import cn.iocoder.yudao.framework.ai.core.model.suno.api.SunoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import com.alibaba.cloud.ai.tongyi.TongYiAutoConfiguration;
import com.alibaba.cloud.ai.tongyi.TongYiConnectionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
import com.alibaba.dashscope.aigc.generation.Generation;
import org.springframework.ai.autoconfigure.ollama.OllamaAutoConfiguration;
import org.springframework.ai.autoconfigure.openai.OpenAiAutoConfiguration;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.image.ImageClient;
import org.springframework.ai.ollama.OllamaChatClient;
import org.springframework.ai.autoconfigure.qianfan.QianFanAutoConfiguration;
import org.springframework.ai.autoconfigure.qianfan.QianFanChatProperties;
import org.springframework.ai.autoconfigure.qianfan.QianFanConnectionProperties;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.image.ImageModel;
import org.springframework.ai.ollama.OllamaChatModel;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.openai.OpenAiChatClient;
import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.ApiUtils;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiImageApi;
import org.springframework.ai.stabilityai.StabilityAiImageClient;
import org.springframework.ai.qianfan.QianFanChatModel;
import org.springframework.ai.qianfan.api.QianFanApi;
import org.springframework.ai.stabilityai.StabilityAiImageModel;
import org.springframework.ai.stabilityai.api.StabilityAiApi;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import java.util.List;
@ -41,9 +50,9 @@ import java.util.List;
public class AiClientFactoryImpl implements AiClientFactory {
@Override
public StreamingChatClient getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(StreamingChatClient.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<StreamingChatClient>) () -> {
public StreamingChatModel getOrCreateStreamingChatClient(AiPlatformEnum platform, String apiKey, String url) {
String cacheKey = buildClientCacheKey(StreamingChatModel.class, platform, apiKey, url);
return Singleton.get(cacheKey, (Func0<StreamingChatModel>) () -> {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
@ -65,39 +74,39 @@ public class AiClientFactoryImpl implements AiClientFactory {
}
@Override
public StreamingChatClient getDefaultStreamingChatClient(AiPlatformEnum platform) {
public StreamingChatModel getDefaultStreamingChatClient(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return SpringUtil.getBean(OpenAiChatClient.class);
return SpringUtil.getBean(OpenAiChatModel.class);
case OLLAMA:
return SpringUtil.getBean(OllamaChatClient.class);
return SpringUtil.getBean(OllamaChatModel.class);
case YI_YAN:
return SpringUtil.getBean(YiYanChatClient.class);
return SpringUtil.getBean(QianFanChatModel.class);
case XING_HUO:
return SpringUtil.getBean(XingHuoChatClient.class);
case QIAN_WEN:
return SpringUtil.getBean(QianWenChatClient.class);
return SpringUtil.getBean(TongYiChatModel.class);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
@Override
public ImageClient getDefaultImageClient(AiPlatformEnum platform) {
public ImageModel getDefaultImageClient(AiPlatformEnum platform) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
return SpringUtil.getBean(OpenAiImageClient.class);
return SpringUtil.getBean(OpenAiImageModel.class);
case STABLE_DIFFUSION:
return SpringUtil.getBean(StabilityAiImageClient.class);
return SpringUtil.getBean(StabilityAiImageModel.class);
default:
throw new IllegalArgumentException(StrUtil.format("未知平台({})", platform));
}
}
@Override
public ImageClient getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
public ImageModel getOrCreateImageClient(AiPlatformEnum platform, String apiKey, String url) {
//noinspection EnhancedSwitchMigration
switch (platform) {
case OPENAI:
@ -109,6 +118,21 @@ public class AiClientFactoryImpl implements AiClientFactory {
}
}
@Override
public MidjourneyApi getOrCreateMidjourneyApi(String apiKey, String url) {
String cacheKey = buildClientCacheKey(MidjourneyApi.class, AiPlatformEnum.MIDJOURNEY.getPlatform(), apiKey, url);
return Singleton.get(cacheKey, (Func0<MidjourneyApi>) () -> {
YudaoAiProperties.MidjourneyProperties properties = SpringUtil.getBean(YudaoAiProperties.class).getMidjourney();
return new MidjourneyApi(url, apiKey, properties.getNotifyUrl());
});
}
@Override
public SunoApi getOrCreateSunoApi(String apiKey, String url) {
String cacheKey = buildClientCacheKey(SunoApi.class, AiPlatformEnum.SUNO.getPlatform(), apiKey, url);
return Singleton.get(cacheKey, (Func0<SunoApi>) () -> new SunoApi(url));
}
private static String buildClientCacheKey(Class<?> clazz, Object... params) {
if (ArrayUtil.isEmpty(params)) {
return clazz.getName();
@ -121,30 +145,31 @@ public class AiClientFactoryImpl implements AiClientFactory {
/**
* 可参考 {@link OpenAiAutoConfiguration}
*/
private static OpenAiChatClient buildOpenAiChatClient(String openAiToken, String url) {
private static OpenAiChatModel buildOpenAiChatClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiApi openAiApi = new OpenAiApi(url, openAiToken);
return new OpenAiChatClient(openAiApi);
return new OpenAiChatModel(openAiApi);
}
/**
* 可参考 {@link OllamaAutoConfiguration}
*/
private static OllamaChatClient buildOllamaChatClient(String url) {
private static OllamaChatModel buildOllamaChatClient(String url) {
OllamaApi ollamaApi = new OllamaApi(url);
return new OllamaChatClient(ollamaApi);
return new OllamaChatModel(ollamaApi);
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#yiYanChatClient(YudaoAiProperties)}
* 可参考 {@link QianFanAutoConfiguration#qianFanChatModel(QianFanConnectionProperties, QianFanChatProperties, RestClient.Builder, RetryTemplate, ResponseErrorHandler)}
*/
private static YiYanChatClient buildYiYanChatClient(String key) {
private static QianFanChatModel buildYiYanChatClient(String key) {
// TODO 芋艿貌似目前设置request 势必会报错
List<String> keys = StrUtil.split(key, '|');
Assert.equals(keys.size(), 2, "YiYanChatClient 的密钥需要 (appKey|secretKey) 格式");
String appKey = keys.get(0);
String secretKey = keys.get(1);
YiYanApi yiYanApi = new YiYanApi(appKey, secretKey, YiYanApi.DEFAULT_CHAT_MODEL, 0);
return new YiYanChatClient(yiYanApi);
QianFanApi qianFanApi = new QianFanApi(appKey, secretKey);
return new QianFanChatModel(qianFanApi);
}
/**
@ -161,11 +186,16 @@ public class AiClientFactoryImpl implements AiClientFactory {
}
/**
* 可参考 {@link YudaoAiAutoConfiguration#qianWenChatClient(YudaoAiProperties)}
* 可参考 {@link TongYiAutoConfiguration#tongYiChatClient(Generation, TongYiChatProperties, TongYiConnectionProperties)}
*/
private static QianWenChatClient buildQianWenChatClient(String key) {
QianWenApi qianWenApi = new QianWenApi(key, QianWenChatModal.QWEN_72B_CHAT);
return new QianWenChatClient(qianWenApi);
private static TongYiChatModel buildQianWenChatClient(String key) {
com.alibaba.dashscope.aigc.generation.Generation generation = SpringUtil.getBean(Generation.class);
TongYiChatProperties chatOptions = SpringUtil.getBean(TongYiChatProperties.class);
// TODO @芋艿貌似 apiKey 是全局唯一的得测试下
// TODO @芋艿貌似阿里云不是增量返回的
TongYiConnectionProperties connectionProperties = new TongYiConnectionProperties();
connectionProperties.setApiKey(key);
return new TongYiAutoConfiguration().tongYiChatClient(generation, chatOptions, connectionProperties);
}
// private static VertexAiGeminiChatClient buildGoogleGemir(String key) {
@ -175,16 +205,16 @@ public class AiClientFactoryImpl implements AiClientFactory {
// return new VertexAiGeminiChatClient(vertexApi);
// }
private ImageClient buildOpenAiImageClient(String openAiToken, String url) {
private OpenAiImageModel buildOpenAiImageClient(String openAiToken, String url) {
url = StrUtil.blankToDefault(url, ApiUtils.DEFAULT_BASE_URL);
OpenAiImageApi openAiApi = new OpenAiImageApi(url, openAiToken, RestClient.builder());
return new OpenAiImageClient(openAiApi);
return new OpenAiImageModel(openAiApi);
}
private ImageClient buildStabilityAiImageClient(String apiKey, String url) {
private StabilityAiImageModel buildStabilityAiImageClient(String apiKey, String url) {
url = StrUtil.blankToDefault(url, StabilityAiApi.DEFAULT_BASE_URL);
StabilityAiApi stabilityAiApi = new StabilityAiApi(apiKey, StabilityAiApi.DEFAULT_IMAGE_MODEL, url);
return new StabilityAiImageClient(stabilityAiApi);
return new StabilityAiImageModel(stabilityAiApi);
}
}

View File

@ -117,7 +117,7 @@ public class SunoApi {
* @param prompt 用于生成音乐音频的提示
* @param tags 音乐风格
* @param title 音乐名称
* @param mv 模型
* @param model 模型
* @param waitAudio false 表示后台模式仅返回音频任务信息需要调用 get API 获取详细的音频信息
* true 表示同步模式API 最多等待 100s音频生成完毕后直接返回音频链接等信息建议在 GPT agent 中使用
* @param makeInstrumental 指示音乐音频是否为定制如果为 true则从歌词生成否则从提示生成
@ -127,7 +127,7 @@ public class SunoApi {
String prompt,
String tags,
String title,
String mv,
String model,
@JsonProperty("wait_audio") boolean waitAudio,
@JsonProperty("make_instrumental") boolean makeInstrumental
) {
@ -136,12 +136,12 @@ public class SunoApi {
this(prompt, null, null, null, false, false);
}
public MusicGenerateRequest(String prompt, String mv, boolean makeInstrumental) {
this(prompt, null, null, mv, false, makeInstrumental);
public MusicGenerateRequest(String prompt, String model, boolean makeInstrumental) {
this(prompt, null, null, model, false, makeInstrumental);
}
public MusicGenerateRequest(String prompt, String mv, String tags, String title) {
this(prompt, tags, title, mv, false, false);
public MusicGenerateRequest(String prompt, String model, String tags, String title) {
this(prompt, tags, title, model, false, false);
}
}
@ -162,6 +162,7 @@ public class SunoApi {
* @param prompt 生成音乐音频的提示
* @param type 操作类型
* @param tags 音乐类型标签
* @param duration 音乐时长
*/
public record MusicData(
String id,
@ -174,9 +175,11 @@ public class SunoApi {
@JsonProperty("model_name") String modelName,
String status,
@JsonProperty("gpt_description_prompt") String gptDescriptionPrompt,
@JsonProperty("error_message") String errorMessage,
String prompt,
String type,
String tags
String tags,
Double duration
) {
}

View File

@ -1,158 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi;
import cn.hutool.core.util.NumberUtil;
import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import org.springframework.ai.chat.*;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message;
import com.google.common.collect.Lists;
import io.reactivex.Flowable;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import reactor.core.publisher.Flux;
import java.time.Duration;
import java.util.List;
import java.util.stream.Collectors;
// TODO @芋艿暂时不需要重构 spring cloud alibaba ai 发布最新的
/**
* 阿里 通义千问 client
* <p>
* 文档地址https://help.aliyun.com/document_detail/2587494.html?spm=a2c4g.2587492.0.0.53f33c566sXskp
* <p>
* author: fansili
* time: 2024/3/13 21:06
*/
@Slf4j
public class QianWenChatClient implements ChatClient, StreamingChatClient {
private QianWenApi qianWenApi;
private QianWenOptions qianWenOptions;
public QianWenChatClient() {
}
public QianWenChatClient(QianWenApi qianWenApi) {
this.qianWenApi = qianWenApi;
}
public QianWenChatClient(QianWenApi qianWenApi, QianWenOptions qianWenOptions) {
this.qianWenApi = qianWenApi;
this.qianWenOptions = qianWenOptions;
}
// TODO @fansili看看咋公用出来允许传入类似异常之类的参数
public final RetryTemplate retryTemplate = RetryTemplate.builder()
// 最大重试次数 10
.maxAttempts(10)
.retryOn(YiYanApiException.class)
// 最大重试5次第一次间隔3000ms第二次3000ms * 2第三次3000ms * 3以此类推最大间隔3 * 60000ms
.exponentialBackoff(Duration.ofMillis(3000), 2, Duration.ofMillis(3 * 60000))
.withListener(new RetryListener() {
@Override
public <T extends Object, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) {
log.warn("重试异常:" + context.getRetryCount(), throwable);
}
})
.build();
@Override
public ChatResponse call(Prompt prompt) {
return this.retryTemplate.execute(ctx -> {
// ctx 会有重试的信息
// 创建 request 请求stream模式需要供应商支持
QwenParam request = this.createRequest(prompt, false);
// 调用 callWithFunctionSupport 发送请求
ResponseEntity<GenerationResult> responseEntity = qianWenApi.chatCompletionEntity(request);
// 获取结果封装 chatCompletion
GenerationResult response = responseEntity.getBody();
// if (!response.isSuccess()) {
// return new ChatResponse(List.of(new Generation(String.format("failed to create completion, requestId: %s, code: %s, message: %s\n",
// response.getRequestId(), response.getCode(), response.getMessage()))));
// }
// 转换为 Generation 返回
return new ChatResponse(response.getOutput().getChoices().stream()
.map(choices -> new Generation(choices.getMessage().getContent()))
.collect(Collectors.toList()));
});
}
private QwenParam createRequest(Prompt prompt, boolean stream) {
// 获取 ChatOptions
QianWenOptions chatOptions = getChatOptions(prompt);
//
List<Message> messageList = Lists.newArrayList();
prompt.getInstructions().stream().forEach(instruction -> {
Message message = new Message();
message.setRole(instruction.getMessageType().getValue());
message.setContent(instruction.getContent());
messageList.add(message);
});
return QwenParam.builder()
.model(qianWenApi.getQianWenChatModal().getModel())
.prompt(prompt.getContents())
.messages(messageList)
.maxTokens(chatOptions.getMaxTokens())
.resultFormat(QwenParam.ResultFormat.MESSAGE)
.topP(chatOptions.getTopP() == null ? null : Double.valueOf(chatOptions.getTopP()))
.topK(chatOptions.getTopK())
.temperature(chatOptions.getTemperature())
// 控制流式输出模式即后面的内容会包含已经输出的内容设置为True将开启增量输出模式后面的输出不会包含已经输出的内容您需要自行拼接整体输出
.incrementalOutput(true)
/* set the random seed, optional, default to 1234 if not set */
.seed(100)
.apiKey(qianWenApi.getApiKey())
.build();
}
private @NotNull QianWenOptions getChatOptions(Prompt prompt) {
// 两个都为null 则没有配置文件
if (qianWenOptions == null && prompt.getOptions() == null) {
throw new ChatException("ChatOptions 未配置参数!");
}
// 优先使用 Prompt 里面的 ChatOptions
ChatOptions options = qianWenOptions;
if (prompt.getOptions() != null) {
options = (ChatOptions) prompt.getOptions();
}
// Prompt 里面是一个 ChatOptions用户可以随意传入这里做一下判断
if (!(options instanceof QianWenOptions)) {
throw new ChatException("Prompt 传入的不是 QianWenOptions!");
}
return (QianWenOptions) options;
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
// ctx 会有重试的信息
// 创建 request 请求stream模式需要供应商支持
QwenParam request = this.createRequest(prompt, true);
// 调用 callWithFunctionSupport 发送请求
Flowable<GenerationResult> responseResult = this.qianWenApi.chatCompletionStream(request);
return Flux.create(fluxSink ->
responseResult.subscribe(
value -> fluxSink.next(
new ChatResponse(value.getOutput().getChoices().stream()
.map(choices -> new Generation(choices.getMessage().getContent()))
.collect(Collectors.toList()))
),
error -> fluxSink.error(error),
() -> fluxSink.complete()
)
);
}
}

View File

@ -1,47 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 千问 chat 模型
*
* 模型地址https://help.aliyun.com/document_detail/2712576.html
* 模型介绍https://help.aliyun.com/document_detail/2666503.html?spm=a2c4g.2701795.0.0.26eb34dfKzcWN4
*
* @author fansili
* @time 2024/4/26 10:15
* @since 1.0
*/
@AllArgsConstructor
@Getter
public enum QianWenChatModal {
// 千问付费模型
QWEN_TURBO("通义千问超大规模语言模型", "qwen-turbo"),
QWEN_PLUS("通义千问超大规模语言模型增强版", "qwen-plus"),
QWEN_MAX("通义千问千亿级别超大规模语言模型", "qwen-max"),
QWEN_MAX_0403("通义千问千亿级别超大规模语言模型-0403", "qwen-max-0403"),
QWEN_MAX_0107("通义千问千亿级别超大规模语言模型-0107", "qwen-max-0107"),
QWEN_MAX_1201("通义千问千亿级别超大规模语言模型-1201", "qwen-max-1201"),
QWEN_MAX_LONGCONTEXT("通义千问千亿级别超大规模语言模型-28k tokens", "qwen-max-longcontext"),
// 开源模型
// https://help.aliyun.com/document_detail/2666503.html?spm=a2c4g.2701795.0.0.26eb34dfKzcWN4
QWEN_72B_CHAT("通义千问1.5对外开源的72B规模参数量的经过人类指令对齐的chat模型", "qwen-72b-chat"),
;
private String name;
private String model;
public static QianWenChatModal valueOfModel(String model) {
for (QianWenChatModal itemEnum : QianWenChatModal.values()) {
if (itemEnum.getModel().equals(model)) {
return itemEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + model);
}
}

View File

@ -1,119 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi;
import org.springframework.ai.chat.prompt.ChatOptions;
import lombok.Data;
import lombok.experimental.Accessors;
import java.util.List;
/**
* 阿里云 千问 属性
*
* 地址https://help.aliyun.com/document_detail/2684682.html?spm=a2c4g.2621347.0.0.195117e7Ytpkyo
*
* author: fansili
* time: 2024/3/15 19:57
*/
@Data
@Accessors(chain = true)
public class QianWenOptions implements ChatOptions {
/**
* 用户与模型的对话历史
*/
private List<Message> messages;
/**
* 生成时核采样方法的概率阈值例如取值为0.8时仅保留累计概率之和大于等于0.8的概率分布中的token
* 作为随机采样的候选集取值范围为0,1.0)取值越大生成的随机性越高取值越低生成的随机性越低
* 默认值为0.8注意取值不要大于等于1
*/
private Float topP;
/**
* 用于限制模型生成token的数量max_tokens设置的是生成上限并不表示一定会生成这么多的token数量其中qwen1.5-14b-chatqwen1.5-7b-chatqwen-14b-chat和qwen-7b-chat最大值和默认值均为1500qwen-1.8b-chatqwen-1.8b-longcontext-chat和qwen-72b-chat最大值和默认值均为2000
*/
private Integer maxTokens = 1500;
/**
* 模型
*/
private String model;
/**
* temperature
*/
private Float temperature;
//
// 适配 ChatOptions
@Override
public Float getTemperature() {
return null;
}
@Override
public Integer getTopK() {
return null;
}
public Float getTopP() {
return topP;
}
@Data
@Accessors
public static class Message {
/**
* 角色: systemuser或assistant
*/
private String role;
/**
* 提示词或模型内容
*/
private String content;
}
@Data
@Accessors
public static class Parameters {
/**
* 输出格式, 默认为"text"
* "text"表示旧版本的text
* "message"表示兼容openai的message
*/
private String resultFormat;
/**
* 生成时采样候选集的大小例如取值为50时仅将单次生成中得分最高的50个token组成随机采样的候选集
* 取值越大生成的随机性越高取值越小生成的确定性越高
* 注意如果top_k参数为空或者top_k的值大于100表示不启用top_k策略此时仅有top_p策略生效默认是空
*/
private Integer topK;
/**
* 生成时使用的随机数种子用户控制模型生成内容的随机性
* seed支持无符号64位整数默认值为1234在使用seed时模型将尽可能生成相同或相似的结果但目前不保证每次生成的结果完全相同
*/
private Integer seed;
/**
* 用于控制随机性和多样性的程度具体来说temperature值控制了生成文本时对每个候选词的概率分布进行平滑的程度
* 较高的temperature值会降低概率分布的峰值使得更多的低概率词被选择
* 生成结果更加多样化而较低的temperature值则会增强概率分布的峰值使得高概率词更容易被选择生成结果更加确定
* 取值范围 [0, 2)系统默认值1.0不建议取值为0无意义
*/
private Float temperature;
/**
* 用于限制模型生成token的数量max_tokens设置的是生成上限并不表示一定会生成这么多的token数量
* 其中qwen-turbo 最大值和默认值为1500 qwen-maxqwen-max-1201 qwen-max-longcontext qwen-plus最大值和默认值均为2000
*/
private Integer maxTokens;
/**
* stop参数用于实现内容生成过程的精确控制在生成内容即将包含指定的字符串或token_ids时自动停止生成内容不包含指定的内容
* 例如如果指定stop为"你好"表示将要生成"你好"时停止如果指定stop为[37763, 367]表示将要生成"Observation"时停止
*/
private List<String> stop;
/**
* 用于控制流式输出模式默认False即后面内容会包含已经输出的内容设置为True将开启增量输出模式
* 后面输出不会包含已经输出的内容您需要自行拼接整体输出参考流式输出示例代码
*/
private Boolean incrementalOutput;
}
}

View File

@ -1,60 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi.api;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.exception.AiException;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import io.reactivex.Flowable;
import lombok.Getter;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.ResponseEntity;
// TODO done @fansili是不是挪到 api 包里按照 spring ai 的结构根目录只放 client options
/**
* 阿里 通义千问
* <p>
* author: fansili
* time: 2024/3/13 21:09
*/
@Getter
public class QianWenApi {
// api key 获取地址https://bailian.console.aliyun.com/?spm=5176.28197581.0.0.38db29a4G3GcVb&apiKey=1#/api-key
private String apiKey = "sk-Zsd81gZYg7";
private Generation gen = new Generation();
private QianWenChatModal qianWenChatModal;
public QianWenApi(String apiKey, QianWenChatModal qianWenChatModal) {
this.apiKey = apiKey;
this.qianWenChatModal = qianWenChatModal;
}
public ResponseEntity<GenerationResult> chatCompletionEntity(QwenParam request) {
GenerationResult call;
try {
call = gen.call(request);
} catch (NoApiKeyException e) {
throw new AiException("没有找到apiKey" + e.getMessage());
} catch (InputRequiredException e) {
throw new AiException("chat缺少必填字段" + e.getMessage());
}
// 阿里云的这个 http code 随便设置外面判断是否成功用的 CompletionsResponse.isSuccess
return new ResponseEntity<>(call, HttpStatusCode.valueOf(200));
}
public Flowable<GenerationResult> chatCompletionStream(QwenParam request) {
Flowable<GenerationResult> resultFlowable;
try {
resultFlowable = gen.streamCall(request);
} catch (NoApiKeyException e) {
throw new AiException("没有找到apiKey" + e.getMessage());
} catch (InputRequiredException e) {
throw new AiException("chat缺少必填字段" + e.getMessage());
}
return resultFlowable;
}
}

View File

@ -1,9 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi.api;
/**
* author: fansili
* time: 2024/3/13 21:07
*/
public class QianWenChatCompletion {
}

View File

@ -1,8 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi.api;
/**
* author: fansili
* time: 2024/3/13 21:07
*/
public class QianWenChatCompletionMessage {
}

View File

@ -1,16 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.tongyi.api;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
/**
* 千问
*
* author: fansili
* time: 2024/3/13 21:07
*/
public class QianWenChatCompletionRequest extends QwenParam {
protected QianWenChatCompletionRequest(QwenParamBuilder<?, ?> b) {
super(b);
}
}

View File

@ -1,11 +0,0 @@
/**
* 阿里的 通义千问
*
* 链接https://www.aliyun.com/search?k=%E9%80%9A%E4%B9%89%E5%A4%A7%E6%A8%A1%E5%9E%8B&scene=all
*
* 千问所有模型https://bailian.console.aliyun.com/?spm=5176.28515448.J_TC9GqcHi2edq9zUs9ZsDQ.1.417338b17zJTjy#/efm/my_model
*
* author: fansili
* time: 2024/3/13 21:05
*/
package cn.iocoder.yudao.framework.ai.core.model.tongyi;

View File

@ -6,10 +6,13 @@ import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletion;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoChatCompletionRequest;
import org.springframework.ai.chat.*;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
@ -29,7 +32,7 @@ import java.util.stream.Collectors;
* time: 2024/3/11 10:19
*/
@Slf4j
public class XingHuoChatClient implements ChatClient, StreamingChatClient {
public class XingHuoChatClient implements ChatModel, StreamingChatModel {
private XingHuoApi xingHuoApi;
@ -64,7 +67,6 @@ public class XingHuoChatClient implements ChatClient, StreamingChatClient {
@Override
public ChatResponse call(Prompt prompt) {
return this.retryTemplate.execute(ctx -> {
// ctx 会有重试的信息
// 获取 chatOptions 属性
@ -78,6 +80,12 @@ public class XingHuoChatClient implements ChatClient, StreamingChatClient {
});
}
@Override
public ChatOptions getDefaultOptions() {
// TODO 芋艿需要跟进下
throw new UnsupportedOperationException();
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
// 获取 chatOptions 属性

View File

@ -134,7 +134,7 @@ public class XingHuoApi {
URI uri = URI.create(authUrl);
// 发起 wss 请求并处理响应
// 创建一个 Flux 来处理接收到的消息
Flux<XingHuoChatCompletion> messageFlux = Flux.create(sink -> {
return Flux.create(sink -> {
socketClient.execute(uri, session ->
session.send(Mono.just(session.textMessage(JSONUtil.toJsonStr(request))))
.thenMany(session.receive()
@ -145,6 +145,5 @@ public class XingHuoApi {
.then())
.subscribe(); // 订阅以开始会话
});
return messageFlux;
}
}

View File

@ -1,152 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan;
import cn.hutool.core.bean.BeanUtil;
import cn.iocoder.yudao.framework.ai.core.exception.ChatException;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionResponse;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionRequest;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.RetryCallback;
import org.springframework.retry.RetryContext;
import org.springframework.retry.RetryListener;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import java.time.Duration;
import java.util.List;
import java.util.stream.Collectors;
/**
* 文心一言的 {@link ChatClient} 实现类
*
* @author fansili
*/
@Slf4j
public class YiYanChatClient implements ChatClient, StreamingChatClient {
private final YiYanApi yiYanApi;
private YiYanChatOptions defaultOptions;
// TODO @fan参考 OpenAiChatClient 调整下 retryTemplate使用 RetryUtils.DEFAULT_RETRY_TEMPLATE加允许传入
public YiYanChatClient(YiYanApi yiYanApi) {
this.yiYanApi = yiYanApi;
// TODO @fan这个情况是不是搞个 defaultOptionsOpenAiChatOptions.builder().withModel(OpenAiApi.DEFAULT_CHAT_MODEL).withTemperature(0.7f).build()
}
public YiYanChatClient(YiYanApi yiYanApi, YiYanChatOptions defaultOptions) {
Assert.notNull(yiYanApi, "OllamaApi must not be null");
Assert.notNull(defaultOptions, "DefaultOptions must not be null");
this.yiYanApi = yiYanApi;
this.defaultOptions = defaultOptions;
}
public final RetryTemplate retryTemplate = RetryTemplate.builder()
.maxAttempts(10)
.retryOn(YiYanApiException.class)
.exponentialBackoff(Duration.ofMillis(3000), 2, Duration.ofMillis(3 * 60000))
.withListener(new RetryListener() {
@Override
public <T, E extends Throwable> void onError(RetryContext context,
RetryCallback<T, E> callback, Throwable throwable) {
log.warn("重试异常:" + context.getRetryCount(), throwable);
}
})
.build();
@Override
public ChatResponse call(Prompt prompt) {
YiYanChatCompletionRequest request = createRequest(prompt, false);
return this.retryTemplate.execute(ctx -> {
// 发送请求
ResponseEntity<YiYanChatCompletionResponse> response = yiYanApi.chatCompletionEntity(request);
// 获取结果封装 ChatResponse
YiYanChatCompletionResponse chatCompletion = response.getBody();
// TODO @fan为空时参考 OpenAiChatClient 的封装
// TODO @fanchatResponseMetadata参考 OpenAiChatResponseMetadata.from(completionEntity.getBody())
return new ChatResponse(List.of(new Generation(chatCompletion.getResult())));
});
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
YiYanChatCompletionRequest request = this.createRequest(prompt, true);
// TODO done @fanreturn this.retryTemplate.execute(ctx -> {
return retryTemplate.execute(ctx -> {
// 调用 callWithFunctionSupport 发送请求
Flux<YiYanChatCompletionResponse> response = this.yiYanApi.chatCompletionStream(request);
return response.map(chunk -> {
// System.err.println("---".concat(chunk.getResult()));
// TODO @fanChatResponseMetadata chatResponseMetadata
return new ChatResponse(List.of(new Generation(chunk.getResult())));
});
});
}
private YiYanChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
// 参考 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t 文档system 是独立字段
// 1.1 获取 user assistant
List<YiYanChatCompletionRequest.Message> messageList = prompt.getInstructions().stream()
// 过滤 system
.filter(msg -> MessageType.SYSTEM != msg.getMessageType())
.map(message -> new YiYanChatCompletionRequest.Message()
.setRole(message.getMessageType().getValue()).setContent(message.getContent())
).toList();
// 1.2 获取 system
String systemPrompt = prompt.getInstructions().stream()
.filter(message -> MessageType.SYSTEM == message.getMessageType())
.map(Message::getContent)
.collect(Collectors.joining());
// 3. 创建 request
YiYanChatCompletionRequest request = new YiYanChatCompletionRequest(messageList);
// 复制 YiYanOptions 属性 request 这里 options 属性和 request 基本保持一致
YiYanChatOptions useOptions = getYiYanOptions(prompt);
BeanUtil.copyProperties(useOptions, request);
request.setTopP(useOptions.getTopP())
.setMaxOutputTokens(useOptions.getMaxOutputTokens())
.setTemperature(useOptions.getTemperature())
.setSystem(systemPrompt)
.setStream(stream);
return request;
}
// TODO @fanOptions 的处理参考下 OpenAiChatClient createRequest
private YiYanChatOptions getYiYanOptions(Prompt prompt) {
// 两个都为null 则没有配置文件
if (defaultOptions == null && prompt.getOptions() == null) {
// TODO @fanIllegalArgumentException 参数更好哈
throw new ChatException("ChatOptions 未配置参数!");
}
// 优先使用 Prompt 里面的 ChatOptions
ChatOptions options = defaultOptions;
if (prompt.getOptions() != null) {
options = (ChatOptions) prompt.getOptions();
}
// Prompt 里面是一个 ChatOptions用户可以随意传入这里做一下判断
if (!(options instanceof YiYanChatOptions)) {
// TODO @fanIllegalArgumentException 参数更好哈
// TODO @fan需要兼容 ChatOptionsBuilder 创建出来的
throw new ChatException("Prompt 传入的不是 YiYanOptions!");
}
// 转换 YiYanOptions
return (YiYanChatOptions) options;
}
}

View File

@ -1,91 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatCompletionRequest;
import lombok.Data;
import org.springframework.ai.chat.prompt.ChatOptions;
import java.util.List;
/**
* 文心一言的 {@link ChatOptions} 实现类
*
* 字段说明参考 <a href="https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t">ERNIE-4.0-8K</a>
*
* @author fansili
*/
@Data
public class YiYanChatOptions implements ChatOptions {
/**
* functions 函数
*/
private List<YiYanChatCompletionRequest.Function> functions;
/**
* temperature
*/
private Float temperature;
/**
* topP
*/
private Float topP;
/**
* 通过对已生成的token增加惩罚减少重复生成的现象
*/
private Float penaltyScore;
/**
* stream 模式请求
*/
private Boolean stream;
/**
* system 提示
*/
private String system;
/**
* 生成停止标识当模型生成结果以stop中某个元素结尾时停止文本生成
*/
private List<String> stop;
/**
* 是否强制关闭实时搜索功能
*/
private Boolean disableSearch;
/**
* 是否开启上角标返回
*/
private Boolean enableCitation;
/**
* 输出最大 token
*/
private Integer maxOutputTokens;
/**
* 响应格式 textjson_object
*/
private String responseFormat;
/**
* 用户id
*/
private String userId;
/**
* 在函数调用场景下提示大模型选择指定的函数非强制说明指定的函数名必须在functions中存在
* tip: ERNIE-4.0-8K 模型没有这个字段
*/
private String toolChoice;
@Override
public Float getTemperature() {
return this.temperature;
}
@Override
public Float getTopP() {
return topP;
}
/**
* 百度么有 topK
*/
@Override
public Integer getTopK() {
return null;
}
}

View File

@ -1,106 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.exception.YiYanApiException;
import cn.iocoder.yudao.framework.common.util.json.JsonUtils;
import org.springframework.http.HttpStatusCode;
import org.springframework.http.ResponseEntity;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
/**
* 文心一言 API
*
* @author fansili
*/
public class YiYanApi {
private static final String DEFAULT_BASE_URL = "https://aip.baidubce.com";
private static final String AUTH_2_TOKEN_URI = "/oauth/2.0/token";
public static final YiYanChatModel DEFAULT_CHAT_MODEL = YiYanChatModel.ERNIE4_0;
private final String appKey;
private final String secretKey;
/**
* TODO fan这个是不是要有个刷新机制哈如果目前不需要可以删除掉 refreshTokenSecondTime整体更简洁
*/
private final String token;
/**
* token 刷新时间()
*/
private int refreshTokenSecondTime;
/**
* 发送请求 webClient
*/
private final WebClient webClient;
/**
* 使用的模型
*/
private final YiYanChatModel useChatModel;
// TODO fan看看是不是去掉 refreshTokenSecondTime 字段
public YiYanApi(String appKey, String secretKey, YiYanChatModel useChatModel, int refreshTokenSecondTime) {
this.appKey = appKey;
this.secretKey = secretKey;
this.useChatModel = useChatModel;
this.refreshTokenSecondTime = refreshTokenSecondTime;
this.webClient = WebClient.builder().baseUrl(DEFAULT_BASE_URL).build();
// 获取访问令牌
token = getToken();
}
/**
* 获得访问令牌
*
* @see <a href="https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Ilkkrb0i5">文档地址</>
* @return 访问令牌
*/
private String getToken() {
ResponseEntity<YiYanAuthResponse> response = this.webClient.post()
.uri(uriBuilder -> uriBuilder.path(AUTH_2_TOKEN_URI)
.queryParam("grant_type", "client_credentials")
.queryParam("client_id", appKey)
.queryParam("client_secret", secretKey)
.build()
)
.retrieve()
.toEntity(YiYanAuthResponse.class)
.block();
// 检查请求状态
// TODO @fan可以使用 response.getStatusCode().is2xxSuccessful()
if (HttpStatusCode.valueOf(200) != response.getStatusCode()
|| response.getBody() == null) {
// TODO @fan可以使用 IllegalStateException 替代另外最好打印下返回方便排错
throw new YiYanApiException("一言认证失败! apihttps://aip.baidubce.com/oauth/2.0/token 请检查 client_id、client_secret 是否正确!");
}
return response.getBody().getAccess_token();
}
public ResponseEntity<YiYanChatCompletionResponse> chatCompletionEntity(YiYanChatCompletionRequest request) {
// TODO: 2024/3/10 小范 这里错误信息返回的结构不一样
// {"error_code":17,"error_msg":"Open api daily request limit reached"}
return this.webClient.post()
.uri(uriBuilder
-> uriBuilder.path(useChatModel.getUri())
.queryParam("access_token", token)
.build())
.body(Mono.just(JsonUtils.toJsonString(request)), String.class)
.retrieve()
.toEntity(YiYanChatCompletionResponse.class)
.block();
}
public Flux<YiYanChatCompletionResponse> chatCompletionStream(YiYanChatCompletionRequest request) {
return this.webClient.post()
.uri(uriBuilder
-> uriBuilder.path(useChatModel.getUri())
.queryParam("access_token", token)
.build())
.body(Mono.just(request), YiYanChatCompletionRequest.class)
.retrieve()
.bodyToFlux(YiYanChatCompletionResponse.class);
}
}

View File

@ -1,48 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.Data;
// TODO @fan字段驼峰字段注释都可以删除贴个链接就好
/**
* 获取文心一言的 access_token Response
*
* @author fansili
*/
@Data
public class YiYanAuthResponse {
/**
* 访问凭证
*/
private String access_token;
/**
* 有效期Access Token的有效期
* 说明单位是秒有效期30天
*/
private int expires_in;
/**
* 错误码说明响应失败时返回该字段成功时不返回
*/
private String error;
/**
* 错误描述信息帮助理解和解决发生的错误
* 说明响应失败时返回该字段成功时不返回
*/
private String error_description;
/**
* 暂时未使用可忽略
*/
private String session_key;
/**
* 暂时未使用可忽略
*/
private String refresh_token;
/**
* 暂时未使用可忽略
*/
private String scope;
/**
* 暂时未使用可忽略
*/
private String session_secret;
}

View File

@ -1,154 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import com.fasterxml.jackson.annotation.JsonProperty;
import lombok.Data;
import java.util.List;
/**
* 文心一言 Completion Request
*
* 百度千帆文档https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
*
* @author fansili
*/
@Data
public class YiYanChatCompletionRequest {
public YiYanChatCompletionRequest(List<Message> messages) {
this.messages = messages;
}
/**
* 聊天上下文信息
*/
private List<Message> messages;
/**
* functions 函数
*/
private List<Function> functions;
/**
* temperature
*/
private Float temperature;
/**
* topP
*/
@JsonProperty("top_p")
private Float topP;
/**
* 通过对已生成的token增加惩罚减少重复生成的现象
*/
@JsonProperty("penalty_score")
private Float penaltyScore;
/**
* stream 模式
*/
private Boolean stream;
/**
* system 预设角色
*/
private String system;
/**
* 生成停止标识当模型生成结果以stop中某个元素结尾时停止文本生成
*/
private List<String> stop;
/**
* 是否强制关闭实时搜索功能
*/
@JsonProperty("disable_search")
private Boolean disableSearch;
/**
* 是否开启上角标返回
*/
@JsonProperty("enable_citation")
private Boolean enableCitation;
/**
* 最大输出 token
*/
@JsonProperty("max_output_tokens")
private Integer maxOutputTokens;
/**
* 返回格式 textjson_object
*/
@JsonProperty("response_format")
private String responseFormat;
/**
* 用户 id
*/
@JsonProperty("user_id")
private String userId;
/**
* 在函数调用场景下提示大模型选择指定的函数非强制说明指定的函数名必须在functions中存在
* tip: ERNIE-4.0-8K 模型没有这个字段
*/
@JsonProperty("tool_choice")
private String toolChoice;
@Data
public static class Message {
private String role;
private String content;
}
@Data
public static class ToolChoice {
/**
* 指定工具类型function
* 必填:
*/
private String type;
/**
* 指定要使用的函数
* 必填:
*/
private Function function;
/**
* 指定要使用的函数名
* 必填:
*/
private String name;
}
@Data
public static class Function {
/**
* 函数名
* 必填:
*/
private String name;
/**
* 函数描述
* 必填:
*/
private String description;
/**
* 函数请求参数说明
* 1JSON Schema 格式参考JSON Schema描述
* 2如果函数没有请求参数parameters值格式如下
* {"type": "object","properties": {}}
* 必填:
*/
private String parameters;
/**
* 函数响应参数JSON Schema 格式参考JSON Schema描述
* 必填:
*/
private String responses;
/**
* function调用的一些历史示例说明
* 1可以提供正例正常触发和反例无需触发的example
* ·正例从历史请求数据中获取
* ·反例
* 当role = user不会触发请求的query
* 当role = assistant有固定的格式function_call的name为空arguments是空对象:"{}"thought可以填固定的:"我不需要调用任何工具"
* 2兼容之前的 List(example) 格式
*/
private String examples;
}
}

View File

@ -1,92 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.Data;
/**
* 文心一言 Completion Response
*
* 百度链接: https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t
*
* @author fansili
*/
@Data
public class YiYanChatCompletionResponse {
/**
* 本轮对话的id
*/
private String id;
/**
* 回包类型chat.completion多轮对话返回
*/
private String object;
/**
* 时间戳
*/
private int created;
/**
* 表示当前子句的序号只有在流式接口模式下会返回该字段
*/
private int sentence_id;
/**
* 表示当前子句是否是最后一句只有在流式接口模式下会返回该字段
*/
private boolean is_end;
/**
* 当前生成的结果是否被截断
*/
private boolean is_truncated;
/**
* 输出内容标识说明
* · normal输出内容完全由大模型生成未触发截断替换
* · stop输出结果命中入参stop中指定的字段后被截断
* · length达到了最大的token数根据EB返回结果is_truncated来截断
* · content_filter输出内容被截断兜底替换为**
*/
private String finish_reason;
/**
* 搜索数据当请求参数enable_citation为true并且触发搜索时会返回该字段
*/
private String search_info;
/**
* 对话返回结果
*/
private String result;
/**
* 表示用户输入是否存在安全是否关闭当前会话清理历史会话信息
* true表示用户输入存在安全风险建议关闭当前会话清理历史会话信息
* false表示用户输入无安全风险
*/
private boolean need_clear_history;
/**
* 说明
* · 0正常返回
* · 其他非正常
*/
private int flag;
/**
* 当need_clear_history为true时此字段会告知第几轮对话有敏感信息如果是当前问题ban_round=-1
*/
private int ban_round;
/**
* token统计信息
*/
private Usage usage;
@Data
public static class Usage {
/**
* 问题tokens数
*/
private int prompt_tokens;
/**
* 回答tokens数
*/
private int completion_tokens;
/**
* tokens总数
*/
private int total_tokens;
}
}

View File

@ -1,42 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.api;
import lombok.AllArgsConstructor;
import lombok.Getter;
/**
* 文心一言的模型枚举
*
* 可参考 <a href="https://cloud.baidu.com/doc/WENXINWORKSHOP/s/clntwmv7t">百度文档</>
*
* @author fansili
*/
@Getter
@AllArgsConstructor
public enum YiYanChatModel {
ERNIE4_0("ERNIE 4.0", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro"),
ERNIE4_3_5_8K("ERNIE-3.5-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions"),
ERNIE4_3_5_8K_0205("ERNIE-3.5-8K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-0205"),
ERNIE4_3_5_8K_1222("ERNIE-3.5-8K-1222", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-8k-1222"),
ERNIE4_BOT_8K("ERNIE-Bot-8K", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie_bot_8k"),
ERNIE4_3_5_4K_0205("ERNIE-3.5-4K-0205", "/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/ernie-3.5-4k-0205"),
;
/**
* 模型名
*/
private final String model;
/**
* API URL
*/
private final String uri;
public static YiYanChatModel valueOfModel(String model) {
for (YiYanChatModel modelEnum : YiYanChatModel.values()) {
if (modelEnum.getModel().equals(model)) {
return modelEnum;
}
}
throw new IllegalArgumentException("Invalid MessageType value: " + model);
}
}

View File

@ -1,16 +0,0 @@
package cn.iocoder.yudao.framework.ai.core.model.yiyan.exception;
/**
* 一言 api 调用异常
*/
public class YiYanApiException extends RuntimeException {
public YiYanApiException(String message) {
super(message);
}
public YiYanApiException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,253 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi;
import com.alibaba.cloud.ai.tongyi.audio.speech.TongYiAudioSpeechModel;
import com.alibaba.cloud.ai.tongyi.audio.speech.TongYiAudioSpeechProperties;
import com.alibaba.cloud.ai.tongyi.audio.transcription.TongYiAudioTranscriptionModel;
import com.alibaba.cloud.ai.tongyi.audio.transcription.TongYiAudioTranscriptionProperties;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatModel;
import com.alibaba.cloud.ai.tongyi.chat.TongYiChatProperties;
import com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
import com.alibaba.cloud.ai.tongyi.embedding.TongYiTextEmbeddingModel;
import com.alibaba.cloud.ai.tongyi.embedding.TongYiTextEmbeddingProperties;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesModel;
import com.alibaba.cloud.ai.tongyi.image.TongYiImagesProperties;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.audio.asr.transcription.Transcription;
import com.alibaba.dashscope.audio.tts.SpeechSynthesizer;
import com.alibaba.dashscope.common.MessageManager;
import com.alibaba.dashscope.embeddings.TextEmbedding;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.utils.ApiKey;
import com.alibaba.dashscope.utils.Constants;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.boot.autoconfigure.AutoConfiguration;
import org.springframework.boot.autoconfigure.condition.ConditionalOnClass;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.ApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Scope;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@AutoConfiguration
@ConditionalOnClass({
MessageManager.class,
TongYiChatModel.class,
TongYiImagesModel.class,
TongYiAudioSpeechModel.class,
TongYiTextEmbeddingModel.class,
TongYiAudioTranscriptionModel.class
})
@EnableConfigurationProperties({
TongYiChatProperties.class,
TongYiImagesProperties.class,
TongYiAudioSpeechProperties.class,
TongYiConnectionProperties.class,
TongYiTextEmbeddingProperties.class,
TongYiAudioTranscriptionProperties.class
})
public class TongYiAutoConfiguration {
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
public Generation generation() {
return new Generation();
}
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
public MessageManager msgManager() {
return new MessageManager(10);
}
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
public ImageSynthesis imageSynthesis() {
return new ImageSynthesis();
}
@Bean
@Scope("prototype")
@ConditionalOnMissingBean
public SpeechSynthesizer speechSynthesizer() {
return new SpeechSynthesizer();
}
@Bean
@ConditionalOnMissingBean
public Transcription transcription() {
return new Transcription();
}
@Bean
@ConditionalOnMissingBean
public TextEmbedding textEmbedding() {
return new TextEmbedding();
}
@Bean
@ConditionalOnMissingBean
public FunctionCallbackContext springAiFunctionManager(ApplicationContext context) {
FunctionCallbackContext manager = new FunctionCallbackContext();
manager.setApplicationContext(context);
return manager;
}
@Bean
@ConditionalOnProperty(
prefix = TongYiChatProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiChatModel tongYiChatClient(Generation generation,
TongYiChatProperties chatOptions,
TongYiConnectionProperties connectionProperties
) {
settingApiKey(connectionProperties);
return new TongYiChatModel(generation, chatOptions.getOptions());
}
@Bean
@ConditionalOnProperty(
prefix = TongYiImagesProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiImagesModel tongYiImagesClient(
ImageSynthesis imageSynthesis,
TongYiImagesProperties imagesOptions,
TongYiConnectionProperties connectionProperties
) {
settingApiKey(connectionProperties);
return new TongYiImagesModel(imageSynthesis, imagesOptions.getOptions());
}
@Bean
@ConditionalOnProperty(
prefix = TongYiAudioSpeechProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiAudioSpeechModel tongYiAudioSpeechClient(
SpeechSynthesizer speechSynthesizer,
TongYiAudioSpeechProperties speechProperties,
TongYiConnectionProperties connectionProperties
) {
settingApiKey(connectionProperties);
return new TongYiAudioSpeechModel(speechSynthesizer, speechProperties.getOptions());
}
@Bean
@ConditionalOnProperty(
prefix = TongYiAudioTranscriptionProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiAudioTranscriptionModel tongYiAudioTranscriptionClient(
Transcription transcription,
TongYiAudioTranscriptionProperties transcriptionProperties,
TongYiConnectionProperties connectionProperties) {
settingApiKey(connectionProperties);
return new TongYiAudioTranscriptionModel(
transcriptionProperties.getOptions(),
transcription
);
}
@Bean
@ConditionalOnProperty(
prefix = TongYiTextEmbeddingProperties.CONFIG_PREFIX,
name = "enabled",
havingValue = "true",
matchIfMissing = true
)
public TongYiTextEmbeddingModel tongYiTextEmbeddingClient(
TextEmbedding textEmbedding,
TongYiConnectionProperties connectionProperties
) {
settingApiKey(connectionProperties);
return new TongYiTextEmbeddingModel(textEmbedding);
}
/**
* Setting the API key.
* @param connectionProperties {@link TongYiConnectionProperties}
*/
private void settingApiKey(TongYiConnectionProperties connectionProperties) {
String apiKey;
try {
// It is recommended to set the key by defining the api-key in an environment variable.
var envKey = System.getenv(TongYiConstants.SCA_AI_TONGYI_API_KEY);
if (Objects.nonNull(envKey)) {
Constants.apiKey = envKey;
return;
}
if (Objects.nonNull(connectionProperties.getApiKey())) {
apiKey = connectionProperties.getApiKey();
}
else {
apiKey = ApiKey.getApiKey(null);
}
Constants.apiKey = apiKey;
}
catch (NoApiKeyException e) {
throw new TongYiException(e.getMessage());
}
}
}

View File

@ -0,0 +1,52 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi;
import org.springframework.boot.context.properties.ConfigurationProperties;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* Spring Cloud Alibaba AI TongYi LLM connection properties.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiConnectionProperties.CONFIG_PREFIX)
public class TongYiConnectionProperties {
/**
* Spring Cloud Alibaba AI connection configuration Prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "tongyi";
/**
* TongYi LLM API key.
*/
private String apiKey;
public String getApiKey() {
return apiKey;
}
public void setApiKey(String apiKey) {
this.apiKey = apiKey;
}
}

View File

@ -0,0 +1,40 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio;
/**
* More models see: https://help.aliyun.com/zh/dashscope/model-list?spm=a2c4g.11186623.0.i5
* Support all models in list.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public final class AudioSpeechModels {
private AudioSpeechModels() {
}
/**
* Male Voice of the Tongue(舌尖男声).
* zh & en.
* Default sample rate: 48 Hz.
*/
public static final String SAMBERT_ZHICHU_V1 = "sambert-zhichu-v1";
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public final class AudioTranscriptionModels {
private AudioTranscriptionModels() {
}
/**
* Paraformer Chinese and English speech recognition model supports audio or video speech recognition with a sampling rate of 16kHz or above.
*/
public static final String Paraformer_V1 = "paraformer-v1";
/**
* Paraformer Chinese speech recognition model, support 8kHz telephone speech recognition.
*/
public static final String Paraformer_8K_V1 = "paraformer-8k-v1";
/**
* The Paraformer multilingual speech recognition model supports audio or video speech recognition with a sample rate of 16kHz or above.
*/
public static final String Paraformer_MTL_V1 = "paraformer-mtl-v1";
}

View File

@ -0,0 +1,228 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech;
import com.alibaba.cloud.ai.tongyi.audio.AudioSpeechModels;
import com.alibaba.cloud.ai.tongyi.audio.speech.api.*;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioSpeechResponseMetadata;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisParam;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisResult;
import com.alibaba.dashscope.audio.tts.SpeechSynthesizer;
import com.alibaba.dashscope.common.ResultCallback;
import io.reactivex.Flowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.util.Assert;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import java.nio.ByteBuffer;
/**
* TongYiAudioSpeechClient is a client for TongYi audio speech service for Spring Cloud Alibaba AI.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAudioSpeechModel implements SpeechModel, SpeechStreamModel {
private final Logger logger = LoggerFactory.getLogger(getClass());
/**
* Default speed rate.
*/
private static final float SPEED_RATE = 1.0f;
/**
* TongYi models api.
*/
private final SpeechSynthesizer speechSynthesizer;
/**
* TongYi models options.
*/
private final TongYiAudioSpeechOptions defaultOptions;
/**
* TongYiAudioSpeechClient constructor.
* @param speechSynthesizer the speech synthesizer
*/
public TongYiAudioSpeechModel(SpeechSynthesizer speechSynthesizer) {
this(speechSynthesizer, null);
}
/**
* TongYiAudioSpeechClient constructor.
* @param speechSynthesizer the speech synthesizer
* @param tongYiAudioOptions the tongYi audio options
*/
public TongYiAudioSpeechModel(SpeechSynthesizer speechSynthesizer, TongYiAudioSpeechOptions tongYiAudioOptions) {
Assert.notNull(speechSynthesizer, "speechSynthesizer must not be null");
Assert.notNull(tongYiAudioOptions, "tongYiAudioOptions must not be null");
this.speechSynthesizer = speechSynthesizer;
this.defaultOptions = tongYiAudioOptions;
}
/**
* Call the TongYi audio speech service.
* @param text the text message to be converted to audio.
* @return the audio byte buffer.
*/
@Override
public ByteBuffer call(String text) {
var speechRequest = new SpeechPrompt(text);
return call(speechRequest).getResult().getOutput();
}
/**
* Call the TongYi audio speech service.
* @param prompt the speech prompt.
* @return the speech response.
*/
@Override
public SpeechResponse call(SpeechPrompt prompt) {
var SCASpeechParam = merge(prompt.getOptions());
var speechSynthesisParams = toSpeechSynthesisParams(SCASpeechParam);
speechSynthesisParams.setText(prompt.getInstructions().getText());
logger.info(speechSynthesisParams.toString());
var res = speechSynthesizer.call(speechSynthesisParams);
return convert(res, null);
}
/**
* Call the TongYi audio speech service.
* @param prompt the speech prompt.
* @param callback the result callback.
* {@link SpeechSynthesizer#call(SpeechSynthesisParam, ResultCallback)}
*/
public void call(SpeechPrompt prompt, ResultCallback<SpeechSynthesisResult> callback) {
var SCASpeechParam = merge(prompt.getOptions());
var speechSynthesisParams = toSpeechSynthesisParams(SCASpeechParam);
speechSynthesisParams.setText(prompt.getInstructions().getText());
speechSynthesizer.call(speechSynthesisParams, callback);
}
/**
* Stream the TongYi audio speech service.
* @param prompt the speech prompt.
* @return the speech response.
* {@link SpeechSynthesizer#streamCall(SpeechSynthesisParam)}
*/
@Override
public Flux<SpeechResponse> stream(SpeechPrompt prompt) {
var SCASpeechParam = merge(prompt.getOptions());
Flowable<SpeechSynthesisResult> resultFlowable = speechSynthesizer
.streamCall(toSpeechSynthesisParams(SCASpeechParam));
return Flux.from(resultFlowable)
.flatMap(
res -> Flux.just(res.getAudioFrame())
.map(audio -> {
var speech = new Speech(audio);
var respMetadata = TongYiAudioSpeechResponseMetadata.from(res);
return new SpeechResponse(speech, respMetadata);
})
).publishOn(Schedulers.parallel());
}
public TongYiAudioSpeechOptions merge(TongYiAudioSpeechOptions target) {
var mergeBuilder = TongYiAudioSpeechOptions.builder();
mergeBuilder.withModel(defaultOptions.getModel() != null ? defaultOptions.getModel() : target.getModel());
mergeBuilder.withPitch(defaultOptions.getPitch() != null ? defaultOptions.getPitch() : target.getPitch());
mergeBuilder.withRate(defaultOptions.getRate() != null ? defaultOptions.getRate() : target.getRate());
mergeBuilder.withFormat(defaultOptions.getFormat() != null ? defaultOptions.getFormat() : target.getFormat());
mergeBuilder.withSampleRate(defaultOptions.getSampleRate() != null ? defaultOptions.getSampleRate() : target.getSampleRate());
mergeBuilder.withTextType(defaultOptions.getTextType() != null ? defaultOptions.getTextType() : target.getTextType());
mergeBuilder.withVolume(defaultOptions.getVolume() != null ? defaultOptions.getVolume() : target.getVolume());
mergeBuilder.withEnablePhonemeTimestamp(defaultOptions.isEnablePhonemeTimestamp() != null ? defaultOptions.isEnablePhonemeTimestamp() : target.isEnablePhonemeTimestamp());
mergeBuilder.withEnableWordTimestamp(defaultOptions.isEnableWordTimestamp() != null ? defaultOptions.isEnableWordTimestamp() : target.isEnableWordTimestamp());
return mergeBuilder.build();
}
public SpeechSynthesisParam toSpeechSynthesisParams(TongYiAudioSpeechOptions source) {
var mergeBuilder = SpeechSynthesisParam.builder();
mergeBuilder.model(source.getModel() != null ? source.getModel() : AudioSpeechModels.SAMBERT_ZHICHU_V1);
mergeBuilder.text(source.getText() != null ? source.getText() : "");
if (source.getFormat() != null) {
mergeBuilder.format(source.getFormat());
}
if (source.getRate() != null) {
mergeBuilder.rate(source.getRate());
}
if (source.getPitch() != null) {
mergeBuilder.pitch(source.getPitch());
}
if (source.getTextType() != null) {
mergeBuilder.textType(source.getTextType());
}
if (source.getSampleRate() != null) {
mergeBuilder.sampleRate(source.getSampleRate());
}
if (source.isEnablePhonemeTimestamp() != null) {
mergeBuilder.enablePhonemeTimestamp(source.isEnablePhonemeTimestamp());
}
if (source.isEnableWordTimestamp() != null) {
mergeBuilder.enableWordTimestamp(source.isEnableWordTimestamp());
}
if (source.getVolume() != null) {
mergeBuilder.volume(source.getVolume());
}
return mergeBuilder.build();
}
/**
* Convert the TongYi audio speech service result to the speech response.
* @param result the audio byte buffer.
* @param synthesisResult the synthesis result.
* @return the speech response.
*/
private SpeechResponse convert(ByteBuffer result, SpeechSynthesisResult synthesisResult) {
if (synthesisResult == null) {
return new SpeechResponse(new Speech(result));
}
var responseMetadata = TongYiAudioSpeechResponseMetadata.from(synthesisResult);
var speech = new Speech(synthesisResult.getAudioFrame());
return new SpeechResponse(speech, responseMetadata);
}
}

View File

@ -0,0 +1,261 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech;
import com.alibaba.cloud.ai.tongyi.audio.AudioSpeechModels;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisAudioFormat;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisTextType;
import org.springframework.ai.model.ModelOptions;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAudioSpeechOptions implements ModelOptions {
/**
* Audio Speech models.
*/
private String model = AudioSpeechModels.SAMBERT_ZHICHU_V1;
/**
* Text content.
*/
private String text;
/**
* Input text type.
*/
private SpeechSynthesisTextType textType = SpeechSynthesisTextType.PLAIN_TEXT;
/**
* synthesis audio format.
*/
private SpeechSynthesisAudioFormat format = SpeechSynthesisAudioFormat.WAV;
/**
* synthesis audio sample rate.
*/
private Integer sampleRate = 16000;
/**
* synthesis audio volume.
*/
private Integer volume = 50;
/**
* synthesis audio speed.
*/
private Float rate = 1.0f;
/**
* synthesis audio pitch.
*/
private Float pitch = 1.0f;
/**
* enable word level timestamp.
*/
private Boolean enableWordTimestamp = false;
/**
* enable phoneme level timestamp.
*/
private Boolean enablePhonemeTimestamp = false;
public static Builder builder() {
return new Builder();
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public String getText() {
return text;
}
public void setText(String text) {
this.text = text;
}
public SpeechSynthesisTextType getTextType() {
return textType;
}
public void setTextType(SpeechSynthesisTextType textType) {
this.textType = textType;
}
public SpeechSynthesisAudioFormat getFormat() {
return format;
}
public void setFormat(SpeechSynthesisAudioFormat format) {
this.format = format;
}
public Integer getSampleRate() {
return sampleRate;
}
public void setSampleRate(Integer sampleRate) {
this.sampleRate = sampleRate;
}
public Integer getVolume() {
return volume;
}
public void setVolume(Integer volume) {
this.volume = volume;
}
public Float getRate() {
return rate;
}
public void setRate(Float rate) {
this.rate = rate;
}
public Float getPitch() {
return pitch;
}
public void setPitch(Float pitch) {
this.pitch = pitch;
}
public Boolean isEnableWordTimestamp() {
return enableWordTimestamp;
}
public void setEnableWordTimestamp(Boolean enableWordTimestamp) {
this.enableWordTimestamp = enableWordTimestamp;
}
public Boolean isEnablePhonemeTimestamp() {
return enablePhonemeTimestamp;
}
public void setEnablePhonemeTimestamp(Boolean enablePhonemeTimestamp) {
this.enablePhonemeTimestamp = enablePhonemeTimestamp;
}
/**
* Build a options instances.
*/
public static class Builder {
private final TongYiAudioSpeechOptions options = new TongYiAudioSpeechOptions();
public Builder withModel(String model) {
options.model = model;
return this;
}
public Builder withText(String text) {
options.text = text;
return this;
}
public Builder withTextType(SpeechSynthesisTextType textType) {
options.textType = textType;
return this;
}
public Builder withFormat(SpeechSynthesisAudioFormat format) {
options.format = format;
return this;
}
public Builder withSampleRate(Integer sampleRate) {
options.sampleRate = sampleRate;
return this;
}
public Builder withVolume(Integer volume) {
options.volume = volume;
return this;
}
public Builder withRate(Float rate) {
options.rate = rate;
return this;
}
public Builder withPitch(Float pitch) {
options.pitch = pitch;
return this;
}
public Builder withEnableWordTimestamp(Boolean enableWordTimestamp) {
options.enableWordTimestamp = enableWordTimestamp;
return this;
}
public Builder withEnablePhonemeTimestamp(Boolean enablePhonemeTimestamp) {
options.enablePhonemeTimestamp = enablePhonemeTimestamp;
return this;
}
public TongYiAudioSpeechOptions build() {
return options;
}
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech;
import com.alibaba.cloud.ai.tongyi.audio.AudioSpeechModels;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisAudioFormat;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* TongYi audio speech configuration properties.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiAudioSpeechProperties.CONFIG_PREFIX)
public class TongYiAudioSpeechProperties {
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "audio.speech";
/**
* Default TongYi Chat model.
*/
public static final String DEFAULT_AUDIO_MODEL_NAME = AudioSpeechModels.SAMBERT_ZHICHU_V1;
/**
* Enable TongYiQWEN ai audio client.
*/
private boolean enabled = true;
@NestedConfigurationProperty
private TongYiAudioSpeechOptions options = TongYiAudioSpeechOptions.builder()
.withModel(DEFAULT_AUDIO_MODEL_NAME)
.withFormat(SpeechSynthesisAudioFormat.WAV)
.build();
public TongYiAudioSpeechOptions getOptions() {
return this.options;
}
public void setOptions(TongYiAudioSpeechOptions options) {
this.options = options;
}
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@ -0,0 +1,87 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import org.springframework.ai.model.ModelResult;
import org.springframework.lang.Nullable;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class Speech implements ModelResult<ByteBuffer> {
private final ByteBuffer audio;
private SpeechMetadata speechMetadata;
public Speech(ByteBuffer audio) {
this.audio = audio;
}
@Override
public ByteBuffer getOutput() {
return this.audio;
}
@Override
public SpeechMetadata getMetadata() {
return speechMetadata != null ? speechMetadata : SpeechMetadata.NULL;
}
public Speech withSpeechMetadata(@Nullable SpeechMetadata speechMetadata) {
this.speechMetadata = speechMetadata;
return this;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof Speech that)) {
return false;
}
return Arrays.equals(audio.array(), that.audio.array())
&& Objects.equals(speechMetadata, that.speechMetadata);
}
@Override
public int hashCode() {
return Objects.hash(Arrays.hashCode(audio.array()), speechMetadata);
}
@Override
public String toString() {
return "Speech{" + "text=" + audio + ", speechMetadata=" + speechMetadata + '}';
}
}

View File

@ -0,0 +1,80 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import java.util.Objects;
/**
* The {@link SpeechMessage} class represents a single text message to
* be converted to speech by the TongYi LLM TTS.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class SpeechMessage {
private String text;
/**
* Constructs a new {@link SpeechMessage} object with the given text.
* @param text the text to be converted to speech
*/
public SpeechMessage(String text) {
this.text = text;
}
/**
* Returns the text of this speech message.
* @return the text of this speech message
*/
public String getText() {
return text;
}
/**
* Sets the text of this speech message.
* @param text the new text for this speech message
*/
public void setText(String text) {
this.text = text;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SpeechMessage that)) {
return false;
}
return Objects.equals(text, that.text);
}
@Override
public int hashCode() {
return Objects.hash(text);
}
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import org.springframework.ai.model.ResultMetadata;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public interface SpeechMetadata extends ResultMetadata {
/**
* Null Object.
*/
SpeechMetadata NULL = SpeechMetadata.create();
/**
* Factory method used to construct a new {@link SpeechMetadata}.
* @return a new {@link SpeechMetadata}
*/
static SpeechMetadata create() {
return new SpeechMetadata() {
};
}
}

View File

@ -0,0 +1,51 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import org.springframework.ai.model.Model;
import java.nio.ByteBuffer;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.0.0-RC1
*/
@FunctionalInterface
public interface SpeechModel extends Model<SpeechPrompt, SpeechResponse> {
/**
* Generates spoken audio from the provided text message.
* @param message the text message to be converted to audio.
* @return the resulting audio bytes.
*/
default ByteBuffer call(String message) {
SpeechPrompt prompt = new SpeechPrompt(message);
return call(prompt).getResult().getOutput();
}
/**
* Sends a speech request to the TongYi TTS API and returns the resulting speech response.
* @param request the speech prompt containing the input text and other parameters.
* @return the speech response containing the generated audio.
*/
SpeechResponse call(SpeechPrompt request);
}

View File

@ -0,0 +1,89 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import com.alibaba.cloud.ai.tongyi.audio.speech.TongYiAudioSpeechOptions;
import org.springframework.ai.model.ModelRequest;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class SpeechPrompt implements ModelRequest<SpeechMessage> {
private TongYiAudioSpeechOptions speechOptions;
private final SpeechMessage message;
public SpeechPrompt(String instructions) {
this(new SpeechMessage(instructions), TongYiAudioSpeechOptions.builder().build());
}
public SpeechPrompt(String instructions, TongYiAudioSpeechOptions speechOptions) {
this(new SpeechMessage(instructions), speechOptions);
}
public SpeechPrompt(SpeechMessage speechMessage) {
this(speechMessage, TongYiAudioSpeechOptions.builder().build());
}
public SpeechPrompt(SpeechMessage speechMessage, TongYiAudioSpeechOptions speechOptions) {
this.message = speechMessage;
this.speechOptions = speechOptions;
}
@Override
public SpeechMessage getInstructions() {
return this.message;
}
@Override
public TongYiAudioSpeechOptions getOptions() {
return speechOptions;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SpeechPrompt that)) {
return false;
}
return Objects.equals(speechOptions, that.speechOptions) && Objects.equals(message, that.message);
}
@Override
public int hashCode() {
return Objects.hash(speechOptions, message);
}
}

View File

@ -0,0 +1,100 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioSpeechResponseMetadata;
import org.springframework.ai.model.ModelResponse;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class SpeechResponse implements ModelResponse<Speech> {
private final Speech speech;
private final TongYiAudioSpeechResponseMetadata speechResponseMetadata;
/**
* Creates a new instance of SpeechResponse with the given speech result.
* @param speech the speech result to be set in the SpeechResponse
* @see Speech
*/
public SpeechResponse(Speech speech) {
this(speech, TongYiAudioSpeechResponseMetadata.NULL);
}
/**
* Creates a new instance of SpeechResponse with the given speech result and speech
* response metadata.
* @param speech the speech result to be set in the SpeechResponse
* @param speechResponseMetadata the speech response metadata to be set in the
* SpeechResponse
* @see Speech
* @see TongYiAudioSpeechResponseMetadata
*/
public SpeechResponse(Speech speech, TongYiAudioSpeechResponseMetadata speechResponseMetadata) {
this.speech = speech;
this.speechResponseMetadata = speechResponseMetadata;
}
@Override
public Speech getResult() {
return speech;
}
@Override
public List<Speech> getResults() {
return Collections.singletonList(speech);
}
@Override
public TongYiAudioSpeechResponseMetadata getMetadata() {
return speechResponseMetadata;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof SpeechResponse that)) {
return false;
}
return Objects.equals(speech, that.speech)
&& Objects.equals(speechResponseMetadata, that.speechResponseMetadata);
}
@Override
public int hashCode() {
return Objects.hash(speech, speechResponseMetadata);
}
}

View File

@ -0,0 +1,54 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.speech.api;
import org.springframework.ai.model.StreamingModel;
import reactor.core.publisher.Flux;
import java.nio.ByteBuffer;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@FunctionalInterface
public interface SpeechStreamModel extends StreamingModel<SpeechPrompt, SpeechResponse> {
/**
* Generates a stream of audio bytes from the provided text message.
*
* @param message the text message to be converted to audio
* @return a Flux of audio bytes representing the generated speech
*/
default Flux<ByteBuffer> stream(String message) {
SpeechPrompt prompt = new SpeechPrompt(message);
return stream(prompt).map(SpeechResponse::getResult).map(Speech::getOutput);
}
/**
* Sends a speech request to the TongYi TTS API and returns a stream of the resulting
* speech responses.
* @param prompt the speech prompt containing the input text and other parameters.
* @return a Flux of speech responses, each containing a portion of the generated audio.
*/
@Override
Flux<SpeechResponse> stream(SpeechPrompt prompt);
}

View File

@ -0,0 +1,186 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription;
import com.alibaba.cloud.ai.tongyi.audio.AudioTranscriptionModels;
import com.alibaba.cloud.ai.tongyi.audio.transcription.api.AudioTranscriptionPrompt;
import com.alibaba.cloud.ai.tongyi.audio.transcription.api.AudioTranscriptionResponse;
import com.alibaba.cloud.ai.tongyi.audio.transcription.api.AudioTranscriptionResult;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioTranscriptionResponseMetadata;
import com.alibaba.dashscope.audio.asr.transcription.*;
import org.springframework.ai.model.Model;
import org.springframework.core.io.Resource;
import org.springframework.util.Assert;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
/**
* TongYiAudioTranscriptionModel is a client for TongYi audio transcription service for
* Spring Cloud Alibaba AI.
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public class TongYiAudioTranscriptionModel
implements Model<AudioTranscriptionPrompt, AudioTranscriptionResponse> {
/**
* TongYi models options.
*/
private final TongYiAudioTranscriptionOptions defaultOptions;
/**
* TongYi models api.
*/
private final Transcription transcription;
public TongYiAudioTranscriptionModel(Transcription transcription) {
this(null, transcription);
}
public TongYiAudioTranscriptionModel(TongYiAudioTranscriptionOptions defaultOptions,
Transcription transcription) {
Assert.notNull(transcription, "transcription must not be null");
Assert.notNull(defaultOptions, "defaultOptions must not be null");
this.defaultOptions = defaultOptions;
this.transcription = transcription;
}
@Override
public AudioTranscriptionResponse call(AudioTranscriptionPrompt prompt) {
TranscriptionParam transcriptionParam;
if (prompt.getOptions() != null) {
var param = merge(prompt.getOptions());
transcriptionParam = toTranscriptionParam(param);
transcriptionParam.setFileUrls(prompt.getOptions().getFileUrls());
}
else {
Resource instructions = prompt.getInstructions();
try {
transcriptionParam = TranscriptionParam.builder()
.model(AudioTranscriptionModels.Paraformer_V1)
.fileUrls(List.of(String.valueOf(instructions.getURL())))
.build();
}
catch (IOException e) {
throw new TongYiException("Failed to create resource", e);
}
}
List<TranscriptionTaskResult> taskResultList;
try {
// Submit a transcription request
TranscriptionResult result = transcription.asyncCall(transcriptionParam);
// Wait for the transcription to complete
result = transcription.wait(TranscriptionQueryParam
.FromTranscriptionParam(transcriptionParam, result.getTaskId()));
// Get the transcription results
System.out.println(result.getOutput().getAsJsonObject());
taskResultList = result.getResults();
System.out.println(Arrays.toString(taskResultList.toArray()));
return new AudioTranscriptionResponse(
taskResultList.stream().map(taskResult ->
new AudioTranscriptionResult(taskResult.getTranscriptionUrl())
).collect(Collectors.toList()),
TongYiAudioTranscriptionResponseMetadata.from(result)
);
}
catch (Exception e) {
throw new TongYiException("Failed to call audio transcription", e);
}
}
public TongYiAudioTranscriptionOptions merge(TongYiAudioTranscriptionOptions target) {
var mergeBuilder = TongYiAudioTranscriptionOptions.builder();
mergeBuilder
.withModel(defaultOptions.getModel() != null ? defaultOptions.getModel()
: target.getModel());
mergeBuilder.withChannelId(
defaultOptions.getChannelId() != null ? defaultOptions.getChannelId()
: target.getChannelId());
mergeBuilder.withDiarizationEnabled(defaultOptions.getDiarizationEnabled() != null
? defaultOptions.getDiarizationEnabled()
: target.getDiarizationEnabled());
mergeBuilder.withDisfluencyRemovalEnabled(
defaultOptions.getDisfluencyRemovalEnabled() != null
? defaultOptions.getDisfluencyRemovalEnabled()
: target.getDisfluencyRemovalEnabled());
mergeBuilder.withTimestampAlignmentEnabled(
defaultOptions.getTimestampAlignmentEnabled() != null
? defaultOptions.getTimestampAlignmentEnabled()
: target.getTimestampAlignmentEnabled());
mergeBuilder.withSpecialWordFilter(defaultOptions.getSpecialWordFilter() != null
? defaultOptions.getSpecialWordFilter()
: target.getSpecialWordFilter());
mergeBuilder.withAudioEventDetectionEnabled(
defaultOptions.getAudioEventDetectionEnabled() != null
? defaultOptions.getAudioEventDetectionEnabled()
: target.getAudioEventDetectionEnabled());
return mergeBuilder.build();
}
public TranscriptionParam toTranscriptionParam(
TongYiAudioTranscriptionOptions source) {
var mergeBuilder = TranscriptionParam.builder();
mergeBuilder.model(source.getModel() != null ? source.getModel()
: AudioTranscriptionModels.Paraformer_V1);
mergeBuilder.fileUrls(
source.getFileUrls() != null ? source.getFileUrls() : new ArrayList<>());
if (source.getPhraseId() != null) {
mergeBuilder.phraseId(source.getPhraseId());
}
if (source.getChannelId() != null) {
mergeBuilder.channelId(source.getChannelId());
}
if (source.getDiarizationEnabled() != null) {
mergeBuilder.diarizationEnabled(source.getDiarizationEnabled());
}
if (source.getSpeakerCount() != null) {
mergeBuilder.speakerCount(source.getSpeakerCount());
}
if (source.getDisfluencyRemovalEnabled() != null) {
mergeBuilder.disfluencyRemovalEnabled(source.getDisfluencyRemovalEnabled());
}
if (source.getTimestampAlignmentEnabled() != null) {
mergeBuilder.timestampAlignmentEnabled(source.getTimestampAlignmentEnabled());
}
if (source.getSpecialWordFilter() != null) {
mergeBuilder.specialWordFilter(source.getSpecialWordFilter());
}
if (source.getAudioEventDetectionEnabled() != null) {
mergeBuilder
.audioEventDetectionEnabled(source.getAudioEventDetectionEnabled());
}
return mergeBuilder.build();
}
}

View File

@ -0,0 +1,203 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription;
import com.alibaba.cloud.ai.tongyi.audio.AudioTranscriptionModels;
import org.springframework.ai.model.ModelOptions;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public class TongYiAudioTranscriptionOptions implements ModelOptions {
private String model = AudioTranscriptionModels.Paraformer_V1;
private List<String> fileUrls = new ArrayList<>();
private String phraseId = null;
private List<Integer> channelId = Collections.singletonList(0);
private Boolean diarizationEnabled = false;
private Integer speakerCount = null;
private Boolean disfluencyRemovalEnabled = false;
private Boolean timestampAlignmentEnabled = false;
private String specialWordFilter = "";
private Boolean audioEventDetectionEnabled = false;
public static Builder builder() {
return new Builder();
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public List<String> getFileUrls() {
return fileUrls;
}
public void setFileUrls(List<String> fileUrls) {
this.fileUrls = fileUrls;
}
public String getPhraseId() {
return phraseId;
}
public void setPhraseId(String phraseId) {
this.phraseId = phraseId;
}
public List<Integer> getChannelId() {
return channelId;
}
public void setChannelId(List<Integer> channelId) {
this.channelId = channelId;
}
public Boolean getDiarizationEnabled() {
return diarizationEnabled;
}
public void setDiarizationEnabled(Boolean diarizationEnabled) {
this.diarizationEnabled = diarizationEnabled;
}
public Integer getSpeakerCount() {
return speakerCount;
}
public void setSpeakerCount(Integer speakerCount) {
this.speakerCount = speakerCount;
}
public Boolean getDisfluencyRemovalEnabled() {
return disfluencyRemovalEnabled;
}
public void setDisfluencyRemovalEnabled(Boolean disfluencyRemovalEnabled) {
this.disfluencyRemovalEnabled = disfluencyRemovalEnabled;
}
public Boolean getTimestampAlignmentEnabled() {
return timestampAlignmentEnabled;
}
public void setTimestampAlignmentEnabled(Boolean timestampAlignmentEnabled) {
this.timestampAlignmentEnabled = timestampAlignmentEnabled;
}
public String getSpecialWordFilter() {
return specialWordFilter;
}
public void setSpecialWordFilter(String specialWordFilter) {
this.specialWordFilter = specialWordFilter;
}
public Boolean getAudioEventDetectionEnabled() {
return audioEventDetectionEnabled;
}
public void setAudioEventDetectionEnabled(Boolean audioEventDetectionEnabled) {
this.audioEventDetectionEnabled = audioEventDetectionEnabled;
}
/**
* Builder class for constructing TongYiAudioTranscriptionOptions instances.
*/
public static class Builder {
private final TongYiAudioTranscriptionOptions options = new TongYiAudioTranscriptionOptions();
public Builder withModel(String model) {
options.model = model;
return this;
}
public Builder withFileUrls(List<String> fileUrls) {
options.fileUrls = fileUrls;
return this;
}
public Builder withPhraseId(String phraseId) {
options.phraseId = phraseId;
return this;
}
public Builder withChannelId(List<Integer> channelId) {
options.channelId = channelId;
return this;
}
public Builder withDiarizationEnabled(Boolean diarizationEnabled) {
options.diarizationEnabled = diarizationEnabled;
return this;
}
public Builder withSpeakerCount(Integer speakerCount) {
options.speakerCount = speakerCount;
return this;
}
public Builder withDisfluencyRemovalEnabled(Boolean disfluencyRemovalEnabled) {
options.disfluencyRemovalEnabled = disfluencyRemovalEnabled;
return this;
}
public Builder withTimestampAlignmentEnabled(Boolean timestampAlignmentEnabled) {
options.timestampAlignmentEnabled = timestampAlignmentEnabled;
return this;
}
public Builder withSpecialWordFilter(String specialWordFilter) {
options.specialWordFilter = specialWordFilter;
return this;
}
public Builder withAudioEventDetectionEnabled(
Boolean audioEventDetectionEnabled) {
options.audioEventDetectionEnabled = audioEventDetectionEnabled;
return this;
}
public TongYiAudioTranscriptionOptions build() {
// Perform any necessary validation here before returning the built object
return options;
}
}
}

View File

@ -0,0 +1,72 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription;
import com.alibaba.cloud.ai.tongyi.audio.AudioTranscriptionModels;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiAudioTranscriptionProperties.CONFIG_PREFIX)
public class TongYiAudioTranscriptionProperties {
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "audio.transcription";
/**
* Default TongYi Chat model.
*/
public static final String DEFAULT_AUDIO_MODEL_NAME = AudioTranscriptionModels.Paraformer_V1;
/**
* Enable TongYiQWEN ai audio client.
*/
private boolean enabled = true;
@NestedConfigurationProperty
private TongYiAudioTranscriptionOptions options = TongYiAudioTranscriptionOptions
.builder().withModel(DEFAULT_AUDIO_MODEL_NAME).build();
public TongYiAudioTranscriptionOptions getOptions() {
return this.options;
}
public void setOptions(TongYiAudioTranscriptionOptions options) {
this.options = options;
}
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@ -0,0 +1,56 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription.api;
import com.alibaba.cloud.ai.tongyi.audio.transcription.TongYiAudioTranscriptionOptions;
import org.springframework.ai.model.ModelRequest;
import org.springframework.core.io.Resource;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public class AudioTranscriptionPrompt implements ModelRequest<Resource> {
private Resource audioResource;
private TongYiAudioTranscriptionOptions transcriptionOptions;
public AudioTranscriptionPrompt(Resource resource) {
this.audioResource = resource;
}
public AudioTranscriptionPrompt(Resource resource, TongYiAudioTranscriptionOptions transcriptionOptions) {
this.audioResource = resource;
this.transcriptionOptions = transcriptionOptions;
}
@Override
public Resource getInstructions() {
return audioResource;
}
@Override
public TongYiAudioTranscriptionOptions getOptions() {
return transcriptionOptions;
}
}

View File

@ -0,0 +1,67 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription.api;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioTranscriptionResponseMetadata;
import org.springframework.ai.model.ModelResponse;
import org.springframework.ai.model.ResponseMetadata;
import java.util.List;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public class AudioTranscriptionResponse implements ModelResponse<AudioTranscriptionResult> {
private List<AudioTranscriptionResult> resultList;
private TongYiAudioTranscriptionResponseMetadata transcriptionResponseMetadata;
public AudioTranscriptionResponse(List<AudioTranscriptionResult> result) {
this(result, TongYiAudioTranscriptionResponseMetadata.NULL);
}
public AudioTranscriptionResponse(List<AudioTranscriptionResult> result,
TongYiAudioTranscriptionResponseMetadata transcriptionResponseMetadata) {
this.resultList = List.copyOf(result);
this.transcriptionResponseMetadata = transcriptionResponseMetadata;
}
@Override
public AudioTranscriptionResult getResult() {
return this.resultList.get(0);
}
@Override
public List<AudioTranscriptionResult> getResults() {
return this.resultList;
}
@Override
public ResponseMetadata getMetadata() {
return this.transcriptionResponseMetadata;
}
}

View File

@ -0,0 +1,68 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.audio.transcription.api;
import com.alibaba.cloud.ai.tongyi.metadata.audio.TongYiAudioTranscriptionMetadata;
import org.springframework.ai.model.ModelResult;
import org.springframework.ai.model.ResultMetadata;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class AudioTranscriptionResult implements ModelResult<String> {
private String text;
private TongYiAudioTranscriptionMetadata transcriptionMetadata;
public AudioTranscriptionResult(String text) {
this.text = text;
}
@Override
public String getOutput() {
return this.text;
}
@Override
public ResultMetadata getMetadata() {
return transcriptionMetadata != null ? transcriptionMetadata : TongYiAudioTranscriptionMetadata.NULL;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
AudioTranscriptionResult that = (AudioTranscriptionResult) o;
return Objects.equals(text, that.text) && Objects.equals(transcriptionMetadata, that.transcriptionMetadata);
}
@Override
public int hashCode() {
return Objects.hash(text, transcriptionMetadata);
}
}

View File

@ -0,0 +1,481 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.chat;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
import com.alibaba.dashscope.aigc.conversation.ConversationParam;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationOutput;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.common.MessageManager;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import com.alibaba.dashscope.tools.FunctionDefinition;
import com.alibaba.dashscope.tools.ToolCallBase;
import com.alibaba.dashscope.tools.ToolCallFunction;
import com.alibaba.dashscope.utils.ApiKeywords;
import com.alibaba.dashscope.utils.JsonUtils;
import io.reactivex.Flowable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.StreamingChatModel;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.scheduler.Schedulers;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
/**
* {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal Alibaba DashScope}
* backed by {@link Generation}.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
* @see ChatModel
* @see com.alibaba.dashscope.aigc.generation
*/
public class TongYiChatModel extends
AbstractFunctionCallSupport<
com.alibaba.dashscope.common.Message,
ConversationParam,
GenerationResult>
implements ChatModel, StreamingChatModel {
private static final Logger logger = LoggerFactory.getLogger(TongYiChatModel.class);
/**
* DashScope generation client.
*/
private final Generation generation;
/**
* The TongYi models default chat completion api.
*/
private TongYiChatOptions defaultOptions;
/**
* User role message manager.
*/
@Autowired
private MessageManager msgManager;
/**
* Initializes an instance of the TongYiChatClient.
* @param generation DashScope generation client.
*/
public TongYiChatModel(Generation generation) {
this(generation,
TongYiChatOptions.builder()
.withTopP(0.8)
.withEnableSearch(true)
.withResultFormat(ConversationParam.ResultFormat.MESSAGE)
.build(),
null
);
}
/**
* Initializes an instance of the TongYiChatClient.
* @param generation DashScope generation client.
* @param options TongYi model params.
*/
public TongYiChatModel(Generation generation, TongYiChatOptions options) {
this(generation, options, null);
}
/**
* Create a TongYi models client.
* @param generation DashScope model generation client.
* @param options TongYi default chat completion api.
*/
public TongYiChatModel(Generation generation, TongYiChatOptions options,
FunctionCallbackContext functionCallbackContext) {
super(functionCallbackContext);
this.generation = generation;
this.defaultOptions = options;
}
/**
* Get default sca chat options.
*
* @return TongYiChatOptions default object.
*/
public TongYiChatOptions getDefaultOptions() {
return this.defaultOptions;
}
@Override
public ChatResponse call(Prompt prompt) {
ConversationParam params = toTongYiChatParams(prompt);
// TongYi models context loader.
com.alibaba.dashscope.common.Message message = new com.alibaba.dashscope.common.Message();
message.setRole(Role.USER.getValue());
message.setContent(prompt.getContents());
msgManager.add(message);
params.setMessages(msgManager.get());
logger.trace("TongYi ConversationOptions: {}", params);
GenerationResult chatCompletions = this.callWithFunctionSupport(params);
logger.trace("TongYi ConversationOptions: {}", params);
msgManager.add(chatCompletions);
List<org.springframework.ai.chat.model.Generation> generations =
chatCompletions
.getOutput()
.getChoices()
.stream()
.map(choice ->
new org.springframework.ai.chat.model.Generation(
choice
.getMessage()
.getContent()
).withGenerationMetadata(generateChoiceMetadata(choice)
))
.toList();
return new ChatResponse(generations);
}
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
Flowable<GenerationResult> genRes;
ConversationParam tongYiChatParams = toTongYiChatParams(prompt);
// See https://help.aliyun.com/zh/dashscope/developer-reference/api-details?spm=a2c4g.11186623.0.0.655fc11aRR0jj7#b9ad0a10cfhpe
// tongYiChatParams.setIncrementalOutput(true);
try {
genRes = generation.streamCall(tongYiChatParams);
}
catch (NoApiKeyException | InputRequiredException e) {
logger.warn("TongYi chat client: " + e.getMessage());
throw new TongYiException(e.getMessage());
}
return Flux.from(genRes)
.flatMap(
message -> Flux.just(
message.getOutput()
.getChoices()
.get(0)
.getMessage()
.getContent())
.map(content -> {
var gen = new org.springframework.ai.chat.model.Generation(content)
.withGenerationMetadata(generateChoiceMetadata(
message.getOutput()
.getChoices()
.get(0)
));
return new ChatResponse(List.of(gen));
})
)
.publishOn(Schedulers.parallel());
}
/**
* Configuration properties to Qwen model params.
* Test access.
*
* @param prompt {@link Prompt}
* @return Qwen models params {@link ConversationParam}
*/
public ConversationParam toTongYiChatParams(Prompt prompt) {
Set<String> functionsForThisRequest = new HashSet<>();
List<com.alibaba.dashscope.common.Message> tongYiMessage = prompt.getInstructions().stream()
.map(this::fromSpringAIMessage)
.toList();
ConversationParam chatParams = ConversationParam.builder()
.messages(tongYiMessage)
// models setting
// {@link HalfDuplexServiceParam#models}
.model(Generation.Models.QWEN_TURBO)
// {@link GenerationOutput}
.resultFormat(ConversationParam.ResultFormat.MESSAGE)
.build();
if (this.defaultOptions != null) {
chatParams = merge(chatParams, this.defaultOptions);
Set<String> defaultEnabledFunctions = this.handleFunctionCallbackConfigurations(this.defaultOptions, !IS_RUNTIME_CALL);
functionsForThisRequest.addAll(defaultEnabledFunctions);
}
if (prompt.getOptions() != null) {
if (prompt.getOptions() instanceof ChatOptions runtimeOptions) {
TongYiChatOptions updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(runtimeOptions,
ChatOptions.class, TongYiChatOptions.class);
chatParams = merge(updatedRuntimeOptions, chatParams);
Set<String> promptEnabledFunctions = this.handleFunctionCallbackConfigurations(updatedRuntimeOptions,
IS_RUNTIME_CALL);
functionsForThisRequest.addAll(promptEnabledFunctions);
}
else {
throw new IllegalArgumentException("Prompt options are not of type ConversationParam:"
+ prompt.getOptions().getClass().getSimpleName());
}
}
// Add the enabled functions definitions to the request's tools parameter.
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
List<FunctionDefinition> tools = this.getFunctionTools(functionsForThisRequest);
// todo chatParams.setTools(tools)
}
return chatParams;
}
private ChatGenerationMetadata generateChoiceMetadata(GenerationOutput.Choice choice) {
return ChatGenerationMetadata.from(
String.valueOf(choice.getFinishReason()),
choice.getMessage().getContent()
);
}
private List<FunctionDefinition> getFunctionTools(Set<String> functionNames) {
return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
FunctionDefinition functionDefinition = FunctionDefinition.builder()
.name(functionCallback.getName())
.description(functionCallback.getDescription())
.parameters(JsonUtils.parametersToJsonObject(
ModelOptionsUtils.jsonToMap(functionCallback.getInputTypeSchema())
))
.build();
return functionDefinition;
}).toList();
}
private ConversationParam merge(ConversationParam tongYiParams, TongYiChatOptions scaChatParams) {
if (scaChatParams == null) {
return tongYiParams;
}
return ConversationParam.builder()
.messages(tongYiParams.getMessages())
.maxTokens((tongYiParams.getMaxTokens() != null) ? tongYiParams.getMaxTokens() : scaChatParams.getMaxTokens())
// When merge options. Because ConversationParams is must not null. So is setting.
.model(scaChatParams.getModel())
.resultFormat((tongYiParams.getResultFormat() != null) ? tongYiParams.getResultFormat() : scaChatParams.getResultFormat())
.enableSearch((tongYiParams.getEnableSearch() != null) ? tongYiParams.getEnableSearch() : scaChatParams.getEnableSearch())
.topK((tongYiParams.getTopK() != null) ? tongYiParams.getTopK() : scaChatParams.getTopK())
.topP((tongYiParams.getTopP() != null) ? tongYiParams.getTopP() : scaChatParams.getTopP())
.incrementalOutput((tongYiParams.getIncrementalOutput() != null) ? tongYiParams.getIncrementalOutput() : scaChatParams.getIncrementalOutput())
.temperature((tongYiParams.getTemperature() != null) ? tongYiParams.getTemperature() : scaChatParams.getTemperature())
.repetitionPenalty((tongYiParams.getRepetitionPenalty() != null) ? tongYiParams.getRepetitionPenalty() : scaChatParams.getRepetitionPenalty())
.seed((tongYiParams.getSeed() != null) ? tongYiParams.getSeed() : scaChatParams.getSeed())
.build();
}
private ConversationParam merge(TongYiChatOptions scaChatParams, ConversationParam tongYiParams) {
if (scaChatParams == null) {
return tongYiParams;
}
ConversationParam mergedTongYiParams = ConversationParam.builder()
.model(Generation.Models.QWEN_TURBO)
.messages(tongYiParams.getMessages())
.build();
mergedTongYiParams = merge(tongYiParams, scaChatParams);
if (scaChatParams.getMaxTokens() != null) {
mergedTongYiParams.setMaxTokens(scaChatParams.getMaxTokens());
}
if (scaChatParams.getStop() != null) {
mergedTongYiParams.setStopStrings(scaChatParams.getStop());
}
if (scaChatParams.getTemperature() != null) {
mergedTongYiParams.setTemperature(scaChatParams.getTemperature());
}
if (scaChatParams.getTopK() != null) {
mergedTongYiParams.setTopK(scaChatParams.getTopK());
}
if (scaChatParams.getTopK() != null) {
mergedTongYiParams.setTopK(scaChatParams.getTopK());
}
return mergedTongYiParams;
}
private com.alibaba.dashscope.common.Message fromSpringAIMessage(Message message) {
return switch (message.getMessageType()) {
case USER -> com.alibaba.dashscope.common.Message.builder()
.role(Role.USER.getValue())
.content(message.getContent())
.build();
case SYSTEM -> com.alibaba.dashscope.common.Message.builder()
.role(Role.SYSTEM.getValue())
.content(message.getContent())
.build();
case ASSISTANT -> com.alibaba.dashscope.common.Message.builder()
.role(Role.ASSISTANT.getValue())
.content(message.getContent())
.build();
default -> throw new IllegalArgumentException("Unknown message type " + message.getMessageType());
};
}
@Override
protected ConversationParam doCreateToolResponseRequest(
ConversationParam previousRequest,
com.alibaba.dashscope.common.Message responseMessage,
List<com.alibaba.dashscope.common.Message> conversationHistory
) {
for (ToolCallBase toolCall : responseMessage.getToolCalls()) {
if (toolCall instanceof ToolCallFunction toolCallFunction) {
if (toolCallFunction.getFunction() != null) {
var functionName = toolCallFunction.getFunction().getName();
var functionArguments = toolCallFunction.getFunction().getArguments();
if (!this.functionCallbackRegister.containsKey(functionName)) {
throw new IllegalStateException("No function callback found for function name: " + functionName);
}
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
// Add the function response to the conversation.
conversationHistory
.add(com.alibaba.dashscope.common.Message.builder()
.content(functionResponse)
.role(Role.BOT.getValue())
.toolCallId(toolCall.getId())
.build()
);
}
}
}
ConversationParam newRequest = ConversationParam.builder().messages(conversationHistory).build();
// todo: No @JsonProperty fields.
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ConversationParam.class);
return newRequest;
}
@Override
protected List<com.alibaba.dashscope.common.Message> doGetUserMessages(ConversationParam request) {
return request.getMessages();
}
@Override
protected com.alibaba.dashscope.common.Message doGetToolResponseMessage(GenerationResult response) {
var message = response.getOutput().getChoices().get(0).getMessage();
var assistantMessage = com.alibaba.dashscope.common.Message.builder().role(Role.ASSISTANT.getValue())
.content("").build();
assistantMessage.setToolCalls(message.getToolCalls());
return assistantMessage;
}
@Override
protected GenerationResult doChatCompletion(ConversationParam request) {
GenerationResult result;
try {
result = generation.call(request);
}
catch (NoApiKeyException | InputRequiredException e) {
throw new RuntimeException(e);
}
return result;
}
@Override
protected Flux<GenerationResult> doChatCompletionStream(ConversationParam request) {
final Flowable<GenerationResult> genRes;
try {
genRes = generation.streamCall(request);
}
catch (NoApiKeyException | InputRequiredException e) {
logger.warn("TongYi chat client: " + e.getMessage());
throw new TongYiException(e.getMessage());
}
return Flux.from(genRes);
}
@Override
protected boolean isToolFunctionCall(GenerationResult response) {
if (response == null || CollectionUtils.isEmpty(response.getOutput().getChoices())) {
return false;
}
var choice = response.getOutput().getChoices().get(0);
if (choice == null || choice.getFinishReason() == null) {
return false;
}
return Objects.equals(choice.getFinishReason(), ApiKeywords.TOOL_CALLS);
}
}

View File

@ -0,0 +1,463 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.chat;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.util.Assert;
import java.util.*;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiChatOptions implements FunctionCallingOptions, ChatOptions {
/**
* TongYi Models.
* {@link Generation.Models}
*/
private String model = Generation.Models.QWEN_TURBO;
/**
* The random number seed used in generation, the user controls the randomness of the content generated by the model.
* seed supports unsigned 64-bit integers, with a default value of 1234.
* when using seed, the model will generate the same or similar results as much as possible, but there is currently no guarantee that the results will be exactly the same each time.
*/
private Integer seed = 1234;
/**
* Used to specify the maximum number of tokens that the model can generate when generating content,
* it defines the upper limit of generation but does not guarantee that this number will be generated every time.
* For qwen-turbo the maximum and default values are 1500 tokens.
* The qwen-max, qwen-max-1201, qwen-max-longcontext, and qwen-plus models have a maximum and default value of 2000 tokens.
*/
private Integer maxTokens = 1500;
/**
* The generation process kernel sampling method probability threshold,
* for example, takes the value of 0.8, only retains the smallest set of the most probable tokens with probabilities that add up to greater than or equal to 0.8 as the candidate set.
* The range of values is (0,1.0), the larger the value, the higher the randomness of generation; the lower the value, the higher the certainty of generation.
*/
private Double topP = 0.8;
/**
* The size of the sampling candidate set at the time of generation.
* For example, with a value of 50, only the 50 highest scoring tokens in a single generation will form a randomly sampled candidate set.
* The larger the value, the higher the randomness of the generation; the smaller the value, the higher the certainty of the generation.
* This parameter is not passed by default, and a value of None or when top_k is greater than 100 indicates that the top_k policy is not enabled,
* at which time, only the top_p policy is in effect.
*/
private Integer topK;
/**
* Used to control the repeatability of model generation.
* Increasing repetition_penalty reduces the repetition of model generation. 1.0 means no penalty.
*/
private Double repetitionPenalty = 1.1;
/**
* is used to control the degree of randomness and diversity.
* Specifically, the temperature value controls the extent to which the probability distribution of each candidate word is smoothed when generating text.
* Higher values of temperature reduce the peak of the probability distribution, allowing more low-probability words to be selected and generating more diverse results,
* while lower values of temperature enhance the peak of the probability distribution, making it easier for high-probability words to be selected and generating more certain results.
* Range: [0, 2), 0 is not recommended, meaningless.
* java version >= 2.5.1
*/
private Double temperature = 0.85;
/**
* The stop parameter is used to realize precise control of the content generation process, automatically stopping when the generated content is about to contain the specified string or token_ids,
* and the generated content does not contain the specified content.
* For example, if stop is specified as "Hello", it means stop when "Hello" will be generated; if stop is specified as [37763, 367], it means stop when "Observation" will be generated.
* The stop parameter can be passed as a list of arrays of strings or token_ids to support the scenario of using multiple stops.
* Explanation: Do not mix strings and token_ids in list mode, the element types should be the same in list mode.
*/
private List<String> stop;
/**
* Whether or not to use stream output. When outputting the result in stream mode, the interface returns the result as generator,
* you need to iterate to get the result, the default output is the whole sequence of the current generation for each output,
* the last output is the final result of all the generation, you can change the output mode to non-incremental output by the parameter incremental_output to False.
*/
private Boolean stream = false;
/**
* The model has a built-in Internet search service.
* This parameter controls whether the model refers to the use of Internet search results when generating text. The values are as follows:
* True: enable internet search, the model will use the search result as the reference information in the text generation process, but the model will "judge by itself" whether to use the internet search result based on its internal logic.
* False (default): Internet search is disabled.
*/
private Boolean enableSearch = false;
/**
* [text|message], defaults to text, when it is message,
* the output refers to the message result example.
* It is recommended to prioritize the use of message format.
*/
private String resultFormat = GenerationParam.ResultFormat.MESSAGE;
/**
* Control the streaming output mode, that is, the content will contain the content has been output;
* set to True, will open the incremental output mode, the output will not contain the content has been output,
* you need to splice the whole output, refer to the streaming output sample code.
*/
private Boolean incrementalOutput = false;
/**
* A list of tools that the model can optionally call.
* Currently only functions are supported, and even if multiple functions are entered, the model will only select one to generate the result.
*/
private List<String> tools;
@Override
public Float getTemperature() {
return this.temperature.floatValue();
}
public void setTemperature(Float temperature) {
this.temperature = temperature.doubleValue();
}
@Override
public Float getTopP() {
return this.topP.floatValue();
}
public void setTopP(Float topP) {
this.topP = topP.doubleValue();
}
@Override
public Integer getTopK() {
return this.topK;
}
public void setTopK(Integer topK) {
this.topK = topK;
}
public String getModel() {
return model;
}
public void setModel(String model) {
this.model = model;
}
public Integer getSeed() {
return seed;
}
public String getResultFormat() {
return resultFormat;
}
public void setResultFormat(String resultFormat) {
this.resultFormat = resultFormat;
}
public void setSeed(Integer seed) {
this.seed = seed;
}
public Integer getMaxTokens() {
return maxTokens;
}
public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}
public Float getRepetitionPenalty() {
return repetitionPenalty.floatValue();
}
public void setRepetitionPenalty(Float repetitionPenalty) {
this.repetitionPenalty = repetitionPenalty.doubleValue();
}
public List<String> getStop() {
return stop;
}
public void setStop(List<String> stop) {
this.stop = stop;
}
public Boolean getStream() {
return stream;
}
public void setStream(Boolean stream) {
this.stream = stream;
}
public Boolean getEnableSearch() {
return enableSearch;
}
public void setEnableSearch(Boolean enableSearch) {
this.enableSearch = enableSearch;
}
public Boolean getIncrementalOutput() {
return incrementalOutput;
}
public void setIncrementalOutput(Boolean incrementalOutput) {
this.incrementalOutput = incrementalOutput;
}
public List<String> getTools() {
return tools;
}
public void setTools(List<String> tools) {
this.tools = tools;
}
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
private Set<String> functions = new HashSet<>();
@Override
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
}
@Override
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.functionCallbacks = functionCallbacks;
}
@Override
public Set<String> getFunctions() {
return this.functions;
}
@Override
public void setFunctions(Set<String> functions) {
this.functions = functions;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TongYiChatOptions that = (TongYiChatOptions) o;
return Objects.equals(model, that.model)
&& Objects.equals(seed, that.seed)
&& Objects.equals(maxTokens, that.maxTokens)
&& Objects.equals(topP, that.topP)
&& Objects.equals(topK, that.topK)
&& Objects.equals(repetitionPenalty, that.repetitionPenalty)
&& Objects.equals(temperature, that.temperature)
&& Objects.equals(stop, that.stop)
&& Objects.equals(stream, that.stream)
&& Objects.equals(enableSearch, that.enableSearch)
&& Objects.equals(resultFormat, that.resultFormat)
&& Objects.equals(incrementalOutput, that.incrementalOutput)
&& Objects.equals(tools, that.tools)
&& Objects.equals(functionCallbacks, that.functionCallbacks)
&& Objects.equals(functions, that.functions);
}
@Override
public int hashCode() {
return Objects.hash(
model,
seed,
maxTokens,
topP,
topK,
repetitionPenalty,
temperature,
stop,
stream,
enableSearch,
resultFormat,
incrementalOutput,
tools,
functionCallbacks,
functions
);
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder("TongYiChatOptions{");
sb.append(", model='").append(model).append('\'');
sb.append(", seed=").append(seed);
sb.append(", maxTokens=").append(maxTokens);
sb.append(", topP=").append(topP);
sb.append(", topK=").append(topK);
sb.append(", repetitionPenalty=").append(repetitionPenalty);
sb.append(", temperature=").append(temperature);
sb.append(", stop=").append(stop);
sb.append(", stream=").append(stream);
sb.append(", enableSearch=").append(enableSearch);
sb.append(", resultFormat='").append(resultFormat).append('\'');
sb.append(", incrementalOutput=").append(incrementalOutput);
sb.append(", tools=").append(tools);
sb.append(", functionCallbacks=").append(functionCallbacks);
sb.append(", functions=").append(functions);
sb.append('}');
return sb.toString();
}
public static Builder builder() {
return new Builder();
}
public static class Builder {
protected TongYiChatOptions options;
public Builder() {
this.options = new TongYiChatOptions();
}
public Builder(TongYiChatOptions options) {
this.options = options;
}
public Builder withModel(String model) {
this.options.model = model;
return this;
}
public Builder withMaxTokens(Integer maxTokens) {
this.options.maxTokens = maxTokens;
return this;
}
public Builder withResultFormat(String rf) {
this.options.resultFormat = rf;
return this;
}
public Builder withEnableSearch(Boolean enableSearch) {
this.options.enableSearch = enableSearch;
return this;
}
public Builder withFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
return this;
}
public Builder withFunctions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
return this;
}
public Builder withFunction(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
return this;
}
public Builder withSeed(Integer seed) {
this.options.seed = seed;
return this;
}
public Builder withStop(List<String> stop) {
this.options.stop = stop;
return this;
}
public Builder withTemperature(Double temperature) {
this.options.temperature = temperature;
return this;
}
public Builder withTopP(Double topP) {
this.options.topP = topP;
return this;
}
public Builder withTopK(Integer topK) {
this.options.topK = topK;
return this;
}
public Builder withRepetitionPenalty(Double repetitionPenalty) {
this.options.repetitionPenalty = repetitionPenalty;
return this;
}
public TongYiChatOptions build() {
return this.options;
}
}
}

View File

@ -0,0 +1,83 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.chat;
import com.alibaba.dashscope.aigc.generation.Generation;
import com.alibaba.dashscope.aigc.generation.GenerationParam;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiChatProperties.CONFIG_PREFIX)
public class TongYiChatProperties {
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "chat";
/**
* Default TongYi Chat model.
*/
public static final String DEFAULT_DEPLOYMENT_NAME = Generation.Models.QWEN_TURBO;
/**
* Default temperature speed.
*/
private static final Double DEFAULT_TEMPERATURE = 0.8;
/**
* Enable TongYiQWEN ai chat client.
*/
private boolean enabled = true;
@NestedConfigurationProperty
private TongYiChatOptions options = TongYiChatOptions.builder()
.withModel(DEFAULT_DEPLOYMENT_NAME)
.withTemperature(DEFAULT_TEMPERATURE)
.withEnableSearch(true)
.withResultFormat(GenerationParam.ResultFormat.MESSAGE)
.build();
public TongYiChatOptions getOptions() {
return this.options;
}
public void setOptions(TongYiChatOptions options) {
this.options = options;
}
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@ -0,0 +1,44 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.common.constants;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
*/
public final class TongYiConstants {
private TongYiConstants() {
}
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String SCA_AI_CONFIGURATION = "spring.cloud.ai.tongyi.";
/**
* Spring Cloud Alibaba AI constants prefix.
*/
public static final String SCA_AI = "SPRING_CLOUD_ALIBABA_";
/**
* TongYi AI apikey env name.
*/
public static final String SCA_AI_TONGYI_API_KEY = SCA_AI + "TONGYI_API_KEY";
}

View File

@ -0,0 +1,38 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.common.exception;
/**
* TongYi models runtime exception.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiException extends RuntimeException {
public TongYiException(String message) {
super(message);
}
public TongYiException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,39 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.common.exception;
/**
* TongYi models images exception.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiImagesException extends TongYiException {
public TongYiImagesException(String message) {
super(message);
}
public TongYiImagesException(String message, Throwable cause) {
super(message, cause);
}
}

View File

@ -0,0 +1,84 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.embedding;
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
import org.springframework.ai.embedding.EmbeddingOptions;
import java.util.List;
/**
* @author why_ohh
* @author yuluo
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
* @since 2023.0.1.0
*/
public final class TongYiEmbeddingOptions implements EmbeddingOptions {
private List<String> texts;
private TextEmbeddingParam.TextType textType;
public List<String> getTexts() {
return texts;
}
public void setTexts(List<String> texts) {
this.texts = texts;
}
public TextEmbeddingParam.TextType getTextType() {
return textType;
}
public void setTextType(TextEmbeddingParam.TextType textType) {
this.textType = textType;
}
public static Builder builder() {
return new Builder();
}
public final static class Builder {
private final TongYiEmbeddingOptions options;
private Builder() {
this.options = new TongYiEmbeddingOptions();
}
public Builder withtexts(List<String> texts) {
options.setTexts(texts);
return this;
}
public Builder withtextType(TextEmbeddingParam.TextType textType) {
options.setTextType(textType);
return this;
}
public TongYiEmbeddingOptions build() {
return options;
}
}
}

View File

@ -0,0 +1,175 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.embedding;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiException;
import com.alibaba.cloud.ai.tongyi.metadata.TongYiTextEmbeddingResponseMetadata;
import com.alibaba.dashscope.embeddings.TextEmbedding;
import com.alibaba.dashscope.embeddings.TextEmbeddingParam;
import com.alibaba.dashscope.embeddings.TextEmbeddingResult;
import com.alibaba.dashscope.embeddings.TextEmbeddingResultItem;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.document.MetadataMode;
import org.springframework.ai.embedding.AbstractEmbeddingModel;
import org.springframework.ai.embedding.Embedding;
import org.springframework.ai.embedding.EmbeddingRequest;
import org.springframework.ai.embedding.EmbeddingResponse;
import org.springframework.util.Assert;
import java.util.List;
import java.util.stream.Collectors;
/**
* {@link TongYiTextEmbeddingModel} implementation for {@literal Alibaba DashScope}.
*
* @author why_ohh
* @author yuluo
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
* @since 2023.0.1.0
*/
public class TongYiTextEmbeddingModel extends AbstractEmbeddingModel {
private final Logger logger = LoggerFactory.getLogger(TongYiTextEmbeddingModel.class);
/**
* TongYi Text Embedding client.
*/
private final TextEmbedding textEmbedding;
/**
* {@link MetadataMode}.
*/
private final MetadataMode metadataMode;
private final TongYiEmbeddingOptions defaultOptions;
public TongYiTextEmbeddingModel(TextEmbedding textEmbedding) {
this(textEmbedding, MetadataMode.EMBED);
}
public TongYiTextEmbeddingModel(TextEmbedding textEmbedding, MetadataMode metadataMode) {
this(textEmbedding, metadataMode,
TongYiEmbeddingOptions.builder()
.withtextType(TextEmbeddingParam.TextType.DOCUMENT)
.build()
);
}
public TongYiTextEmbeddingModel(
TextEmbedding textEmbedding,
MetadataMode metadataMode,
TongYiEmbeddingOptions options
) {
Assert.notNull(textEmbedding, "textEmbedding must not be null");
Assert.notNull(metadataMode, "Metadata mode must not be null");
Assert.notNull(options, "TongYiEmbeddingOptions must not be null");
this.metadataMode = metadataMode;
this.textEmbedding = textEmbedding;
this.defaultOptions = options;
}
public TongYiEmbeddingOptions getDefaultOptions() {
return this.defaultOptions;
}
@Override
public List<Double> embed(Document document) {
return this.call(
new EmbeddingRequest(
List.of(document.getFormattedContent(this.metadataMode)),
null)
).getResults().stream()
.map(Embedding::getOutput)
.flatMap(List::stream)
.toList();
}
@Override
public EmbeddingResponse call(EmbeddingRequest request) {
TextEmbeddingParam embeddingParams = toEmbeddingParams(request);
logger.debug("Embedding request: {}", embeddingParams);
TextEmbeddingResult resp;
try {
resp = textEmbedding.call(embeddingParams);
}
catch (NoApiKeyException e) {
throw new TongYiException(e.getMessage());
}
return genEmbeddingResp(resp);
}
private EmbeddingResponse genEmbeddingResp(TextEmbeddingResult result) {
return new EmbeddingResponse(
genEmbeddingList(result.getOutput().getEmbeddings()),
TongYiTextEmbeddingResponseMetadata.from(result.getUsage())
);
}
private List<Embedding> genEmbeddingList(List<TextEmbeddingResultItem> embeddings) {
return embeddings.stream()
.map(embedding ->
new Embedding(
embedding.getEmbedding(),
embedding.getTextIndex()
))
.collect(Collectors.toList());
}
/**
* We recommend setting the model parameters by passing the embedding parameters through the code;
* yml configuration compatibility is not considered here.
* It is not recommended that users set parameters from yml,
* as this reduces the flexibility of the configuration.
* Because the model name keeps changing, strings are used here and constants are undefined:
* Model list: <a href="https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-quick-start">Text Embedding Model List</a>
* @param requestOptions Client params. {@link EmbeddingRequest}
* @return {@link TextEmbeddingParam}
*/
private TextEmbeddingParam toEmbeddingParams(EmbeddingRequest requestOptions) {
TextEmbeddingParam tongYiEmbeddingParams = TextEmbeddingParam.builder()
.texts(requestOptions.getInstructions())
.textType(defaultOptions.getTextType() != null ? defaultOptions.getTextType() : TextEmbeddingParam.TextType.DOCUMENT)
.model("text-embedding-v1")
.build();
try {
tongYiEmbeddingParams.validate();
}
catch (InputRequiredException e) {
throw new TongYiException(e.getMessage());
}
return tongYiEmbeddingParams;
}
}

View File

@ -0,0 +1,50 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.embedding;
import org.springframework.boot.context.properties.ConfigurationProperties;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* @author why_ohh
* @author yuluo
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiTextEmbeddingProperties.CONFIG_PREFIX)
public class TongYiTextEmbeddingProperties {
/**
* Prefix of TongYi Text Embedding properties.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "embedding";
private boolean enabled = true;
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@ -0,0 +1,237 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.image;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiImagesException;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisParam;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.exception.NoApiKeyException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.image.*;
import org.springframework.util.Assert;
import java.io.ByteArrayOutputStream;
import java.net.URL;
import java.util.Base64;
import java.util.stream.Collectors;
import static com.alibaba.cloud.ai.tongyi.metadata.TongYiImagesResponseMetadata.from;
/**
* TongYiImagesClient is a class that implements the ImageClient interface. It provides a
* client for calling the TongYi AI image generation API.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiImagesModel implements ImageModel {
private final Logger logger = LoggerFactory.getLogger(TongYiImagesModel.class);
/**
* Gen Images API.
*/
private final ImageSynthesis imageSynthesis;
/**
* TongYi Gen images properties.
*/
private TongYiImagesOptions defaultOptions;
/**
* Adapt TongYi images api size properties.
*/
private final String sizeConnection = "*";
/**
* Get default images options.
*
* @return Default TongYiImagesOptions.
*/
public TongYiImagesOptions getDefaultOptions() {
return this.defaultOptions;
}
/**
* TongYiImagesClient constructor.
* @param imageSynthesis the image synthesis
* {@link ImageSynthesis}
*/
public TongYiImagesModel(ImageSynthesis imageSynthesis) {
this(imageSynthesis, TongYiImagesOptions.
builder()
.withModel(ImageSynthesis.Models.WANX_V1)
.withN(1)
.build()
);
}
/**
* TongYiImagesClient constructor.
* @param imageSynthesis {@link ImageSynthesis}
* @param imagesOptions {@link TongYiImagesOptions}
*/
public TongYiImagesModel(ImageSynthesis imageSynthesis, TongYiImagesOptions imagesOptions) {
Assert.notNull(imageSynthesis, "ImageSynthesis must not be null");
Assert.notNull(imagesOptions, "TongYiImagesOptions must not be null");
this.imageSynthesis = imageSynthesis;
this.defaultOptions = imagesOptions;
}
/**
* Call the TongYi images service.
* @param imagePrompt the image prompt.
* @return the image response.
* {@link ImageSynthesis#call(ImageSynthesisParam)}
*/
@Override
public ImageResponse call(ImagePrompt imagePrompt) {
ImageSynthesisResult result;
String prompt = imagePrompt.getInstructions().get(0).getText();
var imgParams = ImageSynthesisParam.builder()
.prompt("")
.model(ImageSynthesis.Models.WANX_V1)
.build();
if (this.defaultOptions != null) {
imgParams = merge(this.defaultOptions);
}
if (imagePrompt.getOptions() != null) {
imgParams = merge(toTingYiImageOptions(imagePrompt.getOptions()));
}
imgParams.setPrompt(prompt);
try {
result = imageSynthesis.call(imgParams);
}
catch (NoApiKeyException e) {
logger.error("TongYi models gen images failed: {}.", e.getMessage());
throw new TongYiImagesException(e.getMessage());
}
return convert(result);
}
public ImageSynthesisParam merge(TongYiImagesOptions target) {
var builder = ImageSynthesisParam.builder();
builder.model(this.defaultOptions.getModel() != null ? this.defaultOptions.getModel() : target.getModel());
builder.n(this.defaultOptions.getN() != null ? this.defaultOptions.getN() : target.getN());
builder.size((this.defaultOptions.getHeight() != null && this.defaultOptions.getWidth() != null)
? this.defaultOptions.getHeight() + "*" + this.defaultOptions.getWidth()
: target.getHeight() + "*" + target.getWidth()
);
// prompt is marked non-null but is null.
builder.prompt("");
return builder.build();
}
private ImageResponse convert(ImageSynthesisResult result) {
return new ImageResponse(
result.getOutput().getResults().stream()
.flatMap(value -> value.entrySet().stream())
.map(entry -> {
String key = entry.getKey();
String value = entry.getValue();
try {
String base64Image = convertImageToBase64(value);
Image image = new Image(value, base64Image);
return new ImageGeneration(image);
}
catch (Exception e) {
throw new RuntimeException(e);
}
})
.collect(Collectors.toList()),
from(result)
);
}
public TongYiImagesOptions toTingYiImageOptions(ImageOptions runtimeImageOptions) {
var builder = TongYiImagesOptions.builder();
if (runtimeImageOptions != null) {
if (runtimeImageOptions.getN() != null) {
builder.withN(runtimeImageOptions.getN());
}
if (runtimeImageOptions.getModel() != null) {
builder.withModel(runtimeImageOptions.getModel());
}
if (runtimeImageOptions.getHeight() != null) {
builder.withHeight(runtimeImageOptions.getHeight());
}
if (runtimeImageOptions.getWidth() != null) {
builder.withWidth(runtimeImageOptions.getWidth());
}
// todo ImagesParams.
}
return builder.build();
}
/**
* Convert image to base64.
* @param imageUrl the image url.
* @return the base64 image.
* @throws Exception the exception.
*/
public String convertImageToBase64(String imageUrl) throws Exception {
var url = new URL(imageUrl);
var inputStream = url.openStream();
var outputStream = new ByteArrayOutputStream();
var buffer = new byte[4096];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
outputStream.write(buffer, 0, bytesRead);
}
var imageBytes = outputStream.toByteArray();
String base64Image = Base64.getEncoder().encodeToString(imageBytes);
inputStream.close();
outputStream.close();
return base64Image;
}
}

View File

@ -0,0 +1,187 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.image;
import com.alibaba.cloud.ai.tongyi.common.exception.TongYiImagesException;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import org.springframework.ai.image.ImageOptions;
import java.util.Objects;
/**
* TongYi Image API options.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiImagesOptions implements ImageOptions {
/**
* Specify the model name, currently only wanx-v1 is supported.
*/
private String model = ImageSynthesis.Models.WANX_V1;
/**
* Gen images number.
*/
private Integer n;
/**
* The width of the generated images.
*/
private Integer width = 1024;
/**
* The height of the generated images.
*/
private Integer height = 1024;
@Override
public Integer getN() {
return this.n;
}
@Override
public String getModel() {
return this.model;
}
@Override
public Integer getWidth() {
return this.width;
}
@Override
public Integer getHeight() {
return this.height;
}
@Override
public String getResponseFormat() {
throw new TongYiImagesException("unimplemented!");
}
public void setModel(String model) {
this.model = model;
}
public void setN(Integer n) {
this.n = n;
}
public void setWidth(Integer width) {
this.width = width;
}
public void setHeight(Integer height) {
this.height = height;
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TongYiImagesOptions that = (TongYiImagesOptions) o;
return Objects.equals(model, that.model)
&& Objects.equals(n, that.n)
&& Objects.equals(width, that.width)
&& Objects.equals(height, that.height);
}
@Override
public int hashCode() {
return Objects.hash(model, n, width, height);
}
@Override
public String toString() {
final StringBuilder sb = new StringBuilder("TongYiImagesOptions{");
sb.append("model='").append(model).append('\'');
sb.append(", n=").append(n);
sb.append(", width=").append(width);
sb.append(", height=").append(height);
sb.append('}');
return sb.toString();
}
public static Builder builder() {
return new Builder();
}
public final static class Builder {
private final TongYiImagesOptions options;
private Builder() {
this.options = new TongYiImagesOptions();
}
public Builder withN(Integer n) {
options.setN(n);
return this;
}
public Builder withModel(String model) {
options.setModel(model);
return this;
}
public Builder withWidth(Integer width) {
options.setWidth(width);
return this;
}
public Builder withHeight(Integer height) {
options.setHeight(height);
return this;
}
public TongYiImagesOptions build() {
return options;
}
}
}

View File

@ -0,0 +1,77 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.image;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesis;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import static com.alibaba.cloud.ai.tongyi.common.constants.TongYiConstants.SCA_AI_CONFIGURATION;
/**
* TongYi Image API properties.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
@ConfigurationProperties(TongYiImagesProperties.CONFIG_PREFIX)
public class TongYiImagesProperties {
/**
* Spring Cloud Alibaba AI configuration prefix.
*/
public static final String CONFIG_PREFIX = SCA_AI_CONFIGURATION + "images";
/**
* Default TongYi Chat model.
*/
public static final String DEFAULT_IMAGES_MODEL_NAME = ImageSynthesis.Models.WANX_V1;
/**
* Enable TongYiQWEN ai images client.
*/
private boolean enabled = true;
@NestedConfigurationProperty
private TongYiImagesOptions options = TongYiImagesOptions.builder()
.withModel(DEFAULT_IMAGES_MODEL_NAME)
.withN(1)
.build();
public TongYiImagesOptions getOptions() {
return this.options;
}
public void setOptions(TongYiImagesOptions options) {
this.options = options;
}
public boolean isEnabled() {
return this.enabled;
}
public void setEnabled(boolean enabled) {
this.enabled = enabled;
}
}

View File

@ -0,0 +1,89 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;
import java.util.HashMap;
/**
* {@link ChatResponseMetadata} implementation for {@literal Alibaba DashScope}.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAiChatResponseMetadata extends HashMap<String, Object> implements ChatResponseMetadata {
protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, usage: %3$s, rateLimit: %4$s }";
@SuppressWarnings("all")
public static TongYiAiChatResponseMetadata from(GenerationResult chatCompletions,
PromptMetadata promptFilterMetadata) {
Assert.notNull(chatCompletions, "Alibaba ai ChatCompletions must not be null");
String id = chatCompletions.getRequestId();
TongYiAiUsage usage = TongYiAiUsage.from(chatCompletions);
return new TongYiAiChatResponseMetadata(
id,
usage,
promptFilterMetadata
);
}
private final String id;
private final Usage usage;
private final PromptMetadata promptMetadata;
protected TongYiAiChatResponseMetadata(String id, TongYiAiUsage usage, PromptMetadata promptMetadata) {
this.id = id;
this.usage = usage;
this.promptMetadata = promptMetadata;
}
public String getId() {
return this.id;
}
@Override
public Usage getUsage() {
return this.usage;
}
@Override
public PromptMetadata getPromptMetadata() {
return this.promptMetadata;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getUsage(), getRateLimit());
}
}

View File

@ -0,0 +1,81 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.GenerationUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;
/**
* {@link Usage} implementation for {@literal Alibaba DashScope}.
*
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAiUsage implements Usage {
private final GenerationUsage usage;
public TongYiAiUsage(GenerationUsage usage) {
Assert.notNull(usage, "GenerationUsage must not be null");
this.usage = usage;
}
public static TongYiAiUsage from(GenerationResult chatCompletions) {
Assert.notNull(chatCompletions, "ChatCompletions must not be null");
return from(chatCompletions.getUsage());
}
public static TongYiAiUsage from(GenerationUsage usage) {
return new TongYiAiUsage(usage);
}
protected GenerationUsage getUsage() {
return this.usage;
}
@Override
public Long getPromptTokens() {
throw new UnsupportedOperationException("Unimplemented method 'getPromptTokens'");
}
@Override
public Long getGenerationTokens() {
return this.getUsage().getOutputTokens().longValue();
}
@Override
public Long getTotalTokens() {
return this.getUsage().getTotalTokens().longValue();
}
@Override
public String toString() {
return this.getUsage().toString();
}
}

View File

@ -0,0 +1,131 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisResult;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisTaskMetrics;
import com.alibaba.dashscope.aigc.imagesynthesis.ImageSynthesisUsage;
import org.springframework.ai.image.ImageResponseMetadata;
import org.springframework.util.Assert;
import java.util.HashMap;
import java.util.Objects;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiImagesResponseMetadata extends HashMap<String, Object> implements ImageResponseMetadata {
private final Long created;
private String taskId;
private ImageSynthesisTaskMetrics metrics;
private ImageSynthesisUsage usage;
public static TongYiImagesResponseMetadata from(ImageSynthesisResult synthesisResult) {
Assert.notNull(synthesisResult, "TongYiAiImageResponse must not be null");
return new TongYiImagesResponseMetadata(
System.currentTimeMillis(),
synthesisResult.getOutput().getTaskMetrics(),
synthesisResult.getOutput().getTaskId(),
synthesisResult.getUsage()
);
}
protected TongYiImagesResponseMetadata(
Long created,
ImageSynthesisTaskMetrics metrics,
String taskId,
ImageSynthesisUsage usage
) {
this.taskId = taskId;
this.metrics = metrics;
this.created = created;
this.usage = usage;
}
public ImageSynthesisUsage getUsage() {
return usage;
}
public void setUsage(ImageSynthesisUsage usage) {
this.usage = usage;
}
@Override
public Long getCreated() {
return created;
}
public String getTaskId() {
return taskId;
}
public void setTaskId(String taskId) {
this.taskId = taskId;
}
public ImageSynthesisTaskMetrics getMetrics() {
return metrics;
}
void setMetrics(ImageSynthesisTaskMetrics metrics) {
this.metrics = metrics;
}
public Long created() {
return this.created;
}
@Override
public String toString() {
return "TongYiImagesResponseMetadata {" + "created=" + created + '}';
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
TongYiImagesResponseMetadata that = (TongYiImagesResponseMetadata) o;
return Objects.equals(created, that.created)
&& Objects.equals(taskId, that.taskId)
&& Objects.equals(metrics, that.metrics);
}
@Override
public int hashCode() {
return Objects.hash(created, taskId, metrics);
}
}

View File

@ -0,0 +1,53 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata;
import com.alibaba.dashscope.embeddings.TextEmbeddingUsage;
import org.springframework.ai.embedding.EmbeddingResponseMetadata;
/**
* @author why_ohh
* @author yuluo
* @author <a href="mailto:550588941@qq.com">why_ohh</a>
* @since 2023.0.1.0
*/
public class TongYiTextEmbeddingResponseMetadata extends EmbeddingResponseMetadata {
private Integer totalTokens;
protected TongYiTextEmbeddingResponseMetadata(Integer totalTokens) {
this.totalTokens = totalTokens;
}
public static TongYiTextEmbeddingResponseMetadata from(TextEmbeddingUsage usage) {
return new TongYiTextEmbeddingResponseMetadata(usage.getTotalTokens());
}
public Integer getTotalTokens() {
return totalTokens;
}
public void setTotalTokens(Integer totalTokens) {
this.totalTokens = totalTokens;
}
}

View File

@ -0,0 +1,133 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata.audio;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisResult;
import com.alibaba.dashscope.audio.tts.SpeechSynthesisUsage;
import com.alibaba.dashscope.audio.tts.timestamp.Sentence;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.model.ResponseMetadata;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import java.util.HashMap;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAudioSpeechResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
private SpeechSynthesisUsage usage;
private String requestId;
private Sentence time;
protected static final String AI_METADATA_STRING = "{ @type: %1$s, requestsLimit: %2$s }";
/**
* NULL objects.
*/
public static final TongYiAudioSpeechResponseMetadata NULL = new TongYiAudioSpeechResponseMetadata() {
};
public static TongYiAudioSpeechResponseMetadata from(SpeechSynthesisResult result) {
Assert.notNull(result, "TongYi AI speech must not be null");
TongYiAudioSpeechResponseMetadata speechResponseMetadata = new TongYiAudioSpeechResponseMetadata();
return speechResponseMetadata;
}
public static TongYiAudioSpeechResponseMetadata from(String result) {
Assert.notNull(result, "TongYi AI speech must not be null");
TongYiAudioSpeechResponseMetadata speechResponseMetadata = new TongYiAudioSpeechResponseMetadata();
return speechResponseMetadata;
}
@Nullable
private RateLimit rateLimit;
public TongYiAudioSpeechResponseMetadata() {
this(null);
}
public TongYiAudioSpeechResponseMetadata(@Nullable RateLimit rateLimit) {
this.rateLimit = rateLimit;
}
@Nullable
public RateLimit getRateLimit() {
RateLimit rateLimit = this.rateLimit;
return rateLimit != null ? rateLimit : new EmptyRateLimit();
}
public TongYiAudioSpeechResponseMetadata withRateLimit(RateLimit rateLimit) {
this.rateLimit = rateLimit;
return this;
}
public TongYiAudioSpeechResponseMetadata withUsage(SpeechSynthesisUsage usage) {
this.usage = usage;
return this;
}
public TongYiAudioSpeechResponseMetadata withRequestId(String id) {
this.requestId = id;
return this;
}
public TongYiAudioSpeechResponseMetadata withSentence(Sentence sentence) {
this.time = sentence;
return this;
}
public SpeechSynthesisUsage getUsage() {
return usage;
}
public String getRequestId() {
return requestId;
}
public Sentence getTime() {
return time;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getName(), getRateLimit());
}
}

View File

@ -0,0 +1,43 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata.audio;
import org.springframework.ai.model.ResultMetadata;
/**
* @author xYLiu
* @author yuluo
* @since 2023.0.1.0
*/
public interface TongYiAudioTranscriptionMetadata extends ResultMetadata {
/**
* A constant instance of {@link TongYiAudioTranscriptionMetadata} that represents a null or empty metadata.
*/
TongYiAudioTranscriptionMetadata NULL = TongYiAudioTranscriptionMetadata.create();
/**
* Factory method for creating a new instance of {@link TongYiAudioTranscriptionMetadata}.
* @return a new instance of {@link TongYiAudioTranscriptionMetadata}
*/
static TongYiAudioTranscriptionMetadata create() {
return new TongYiAudioTranscriptionMetadata() {
};
}
}

View File

@ -0,0 +1,98 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.alibaba.cloud.ai.tongyi.metadata.audio;
import com.alibaba.dashscope.audio.asr.transcription.TranscriptionResult;
import com.google.gson.JsonObject;
import org.springframework.ai.chat.metadata.EmptyRateLimit;
import org.springframework.ai.chat.metadata.RateLimit;
import org.springframework.ai.model.ResponseMetadata;
import org.springframework.util.Assert;
import javax.annotation.Nullable;
import java.util.HashMap;
/**
* @author yuluo
* @author <a href="mailto:yuluo08290126@gmail.com">yuluo</a>
* @since 2023.0.1.0
*/
public class TongYiAudioTranscriptionResponseMetadata extends HashMap<String, Object> implements ResponseMetadata {
@Nullable
private RateLimit rateLimit;
private JsonObject usage;
protected static final String AI_METADATA_STRING = "{ @type: %1$s, rateLimit: %4$s }";
/**
* NULL objects.
*/
public static final TongYiAudioTranscriptionResponseMetadata NULL = new TongYiAudioTranscriptionResponseMetadata() {
};
protected TongYiAudioTranscriptionResponseMetadata() {
this(null, new JsonObject());
}
protected TongYiAudioTranscriptionResponseMetadata(JsonObject usage) {
this(null, usage);
}
protected TongYiAudioTranscriptionResponseMetadata(@Nullable RateLimit rateLimit, JsonObject usage) {
this.rateLimit = rateLimit;
this.usage = usage;
}
public static TongYiAudioTranscriptionResponseMetadata from(TranscriptionResult result) {
Assert.notNull(result, "TongYi Transcription must not be null");
return new TongYiAudioTranscriptionResponseMetadata(result.getUsage());
}
@Nullable
public RateLimit getRateLimit() {
return this.rateLimit != null ? this.rateLimit : new EmptyRateLimit();
}
public void setRateLimit(@Nullable RateLimit rateLimit) {
this.rateLimit = rateLimit;
}
public JsonObject getUsage() {
return usage;
}
public void setUsage(JsonObject usage) {
this.usage = usage;
}
@Override
public String toString() {
return AI_METADATA_STRING.formatted(getClass().getName(), getRateLimit());
}
}

View File

@ -1 +1 @@
cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration
cn.iocoder.yudao.framework.ai.config.YudaoAiAutoConfiguration

View File

@ -1,105 +1,105 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
import com.alibaba.dashscope.aigc.generation.GenerationResult;
import com.alibaba.dashscope.aigc.generation.models.QwenParam;
import com.alibaba.dashscope.common.Message;
import com.alibaba.dashscope.common.MessageManager;
import com.alibaba.dashscope.common.Role;
import com.alibaba.dashscope.exception.InputRequiredException;
import com.alibaba.dashscope.exception.NoApiKeyException;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
import java.util.function.Consumer;
// TODO 芋艿整理单测
/**
* author: fansili
* time: 2024/3/13 21:37
*/
public class QianWenChatClientTests {
private QianWenChatClient qianWenChatClient;
@Before
public void setup() {
QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT);
QianWenOptions qianWenOptions = new QianWenOptions();
qianWenOptions.setTopP(0.8F);
// qianWenOptions.setTopK(3); TODO 芋艿临时处理
// qianWenOptions.setTemperature(0.6F); TODO 芋艿临时处理
qianWenChatClient = new QianWenChatClient(
qianWenApi,
qianWenOptions
);
}
@Test
public void callTest() {
List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。"));
messages.add(new UserMessage("长沙怎么样?"));
ChatResponse call = qianWenChatClient.call(new Prompt(messages));
System.err.println(call.getResult());
}
@Test
public void streamTest() {
List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
messages.add(new UserMessage("长沙怎么样?"));
Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt(messages));
flux.subscribe(new Consumer<ChatResponse>() {
@Override
public void accept(ChatResponse chatResponse) {
System.err.print(chatResponse.getResult().getOutput().getContent());
}
});
// 阻止退出
Scanner scanner = new Scanner(System.in);
scanner.nextLine();
}
@Test
public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException {
com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation();
MessageManager msgManager = new MessageManager(10);
Message systemMsg =
Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build();
msgManager.add(systemMsg);
msgManager.add(userMsg);
QwenParam param =
QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get())
.resultFormat(QwenParam.ResultFormat.MESSAGE)
.topP(0.8)
/* set the random seed, optional, default to 1234 if not set */
.seed(100)
.apiKey("sk-Zsd81gZYg7")
.build();
GenerationResult result = gen.call(param);
System.out.println(result);
System.out.println("-----------------");
System.out.println("-----------------");
msgManager.add(result);
param.setPrompt("能否缩短一些,只讲三点");
param.setMessages(msgManager.get());
result = gen.call(param);
System.out.println(result);
}
}
//package cn.iocoder.yudao.framework.ai.chat;
//
//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatClient;
//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenChatModal;
//import cn.iocoder.yudao.framework.ai.core.model.tongyi.QianWenOptions;
//import cn.iocoder.yudao.framework.ai.core.model.tongyi.api.QianWenApi;
//import com.alibaba.dashscope.aigc.generation.GenerationResult;
//import com.alibaba.dashscope.aigc.generation.models.QwenParam;
//import com.alibaba.dashscope.common.Message;
//import com.alibaba.dashscope.common.MessageManager;
//import com.alibaba.dashscope.common.Role;
//import com.alibaba.dashscope.exception.InputRequiredException;
//import com.alibaba.dashscope.exception.NoApiKeyException;
//import org.junit.Before;
//import org.junit.Test;
//import org.springframework.ai.chat.messages.SystemMessage;
//import org.springframework.ai.chat.messages.UserMessage;
//import org.springframework.ai.chat.model.ChatResponse;
//import org.springframework.ai.chat.prompt.Prompt;
//import reactor.core.publisher.Flux;
//
//import java.util.ArrayList;
//import java.util.List;
//import java.util.Scanner;
//import java.util.function.Consumer;
//
//// TODO 芋艿整理单测
///**
// * author: fansili
// * time: 2024/3/13 21:37
// */
//public class QianWenChatClientTests {
//
// private QianWenChatClient qianWenChatClient;
//
// @Before
// public void setup() {
// QianWenApi qianWenApi = new QianWenApi("sk-Zsd81gZYg7", QianWenChatModal.QWEN_72B_CHAT);
// QianWenOptions qianWenOptions = new QianWenOptions();
// qianWenOptions.setTopP(0.8F);
//// qianWenOptions.setTopK(3); TODO 芋艿临时处理
//// qianWenOptions.setTemperature(0.6F); TODO 芋艿临时处理
// qianWenChatClient = new QianWenChatClient(
// qianWenApi,
// qianWenOptions
// );
// }
//
// @Test
// public void callTest() {
// List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
// messages.add(new SystemMessage("你是一个优质的小红书文艺作者,抒写着各城市的美好文化和风景。"));
// messages.add(new UserMessage("长沙怎么样?"));
//
// ChatResponse call = qianWenChatClient.call(new Prompt(messages));
// System.err.println(call.getResult());
// }
//
// @Test
// public void streamTest() {
// List<org.springframework.ai.chat.messages.Message> messages = new ArrayList<>();
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景。"));
// messages.add(new UserMessage("长沙怎么样?"));
//
// Flux<ChatResponse> flux = qianWenChatClient.stream(new Prompt(messages));
// flux.subscribe(new Consumer<ChatResponse>() {
// @Override
// public void accept(ChatResponse chatResponse) {
// System.err.print(chatResponse.getResult().getOutput().getContent());
// }
// });
//
// // 阻止退出
// Scanner scanner = new Scanner(System.in);
// scanner.nextLine();
// }
//
// @Test
// public void qianwenDemoTest() throws NoApiKeyException, InputRequiredException {
// com.alibaba.dashscope.aigc.generation.Generation gen = new com.alibaba.dashscope.aigc.generation.Generation();
// MessageManager msgManager = new MessageManager(10);
// Message systemMsg =
// Message.builder().role(Role.SYSTEM.getValue()).content("You are a helpful assistant.").build();
// Message userMsg = Message.builder().role(Role.USER.getValue()).content("就当前的海洋污染的情况,写一份限塑的倡议书提纲,需要有理有据地号召大家克制地使用塑料制品").build();
// msgManager.add(systemMsg);
// msgManager.add(userMsg);
// QwenParam param =
// QwenParam.builder().model("qwen-72b-chat").messages(msgManager.get())
// .resultFormat(QwenParam.ResultFormat.MESSAGE)
// .topP(0.8)
// /* set the random seed, optional, default to 1234 if not set */
// .seed(100)
// .apiKey("sk-Zsd81gZYg7")
// .build();
// GenerationResult result = gen.call(param);
// System.out.println(result);
// System.out.println("-----------------");
// System.out.println("-----------------");
// msgManager.add(result);
// param.setPrompt("能否缩短一些,只讲三点");
// param.setMessages(msgManager.get());
// result = gen.call(param);
// System.out.println(result);
// }
//}

View File

@ -1,16 +1,16 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatClient;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoChatModel;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.XingHuoOptions;
import cn.iocoder.yudao.framework.ai.core.model.xinghuo.api.XingHuoApi;
import org.junit.Before;
import org.junit.Test;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import reactor.core.publisher.Flux;
import java.util.ArrayList;

View File

@ -1,21 +1,21 @@
package cn.iocoder.yudao.framework.ai.chat;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
import org.junit.Before;
import org.junit.Test;
import reactor.core.publisher.Flux;
import java.util.ArrayList;
import java.util.List;
import java.util.Scanner;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatClient;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.YiYanChatOptions;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanApi;
//import cn.iocoder.yudao.framework.ai.core.model.yiyan.api.YiYanChatModel;
//import org.junit.Before;
//import org.junit.Test;
//import org.springframework.ai.chat.messages.Message;
//import org.springframework.ai.chat.messages.SystemMessage;
//import org.springframework.ai.chat.messages.UserMessage;
//import org.springframework.ai.chat.model.ChatResponse;
//import org.springframework.ai.chat.prompt.Prompt;
//import reactor.core.publisher.Flux;
//
//import java.util.ArrayList;
//import java.util.List;
//import java.util.Scanner;
// TODO 芋艿整理单测
/**
@ -26,49 +26,49 @@ import java.util.Scanner;
*/
public class YiYanChatTests {
private YiYanChatClient yiYanChatClient;
@Before
public void setup() {
YiYanApi yiYanApi = new YiYanApi(
"x0cuLZ7XsaTCU08vuJWO87Lg",
"R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK",
YiYanChatModel.ERNIE4_3_5_8K,
86400
);
YiYanChatOptions yiYanOptions = new YiYanChatOptions();
yiYanOptions.setMaxOutputTokens(2048);
yiYanOptions.setTopP(0.6f);
yiYanOptions.setTemperature(0.85f);
yiYanChatClient = new YiYanChatClient(
yiYanApi,
yiYanOptions
);
}
@Test
public void callTest() {
// tip: 百度的message 有特殊规则(最后一个message为当前请求的信息前面的message为历史对话信息)
// tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
messages.add(new UserMessage("长沙怎么样?"));
ChatResponse call = yiYanChatClient.call(new Prompt(messages));
System.err.println(call.getResult());
}
@Test
public void streamTest() {
List<Message> messages = new ArrayList<>();
messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
messages.add(new UserMessage("长沙怎么样?"));
Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt(messages));
fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
// 阻止退出
Scanner scanner = new Scanner(System.in);
scanner.nextLine();
}
// private YiYanChatClient yiYanChatClient;
//
// @Before
// public void setup() {
// YiYanApi yiYanApi = new YiYanApi(
// "x0cuLZ7XsaTCU08vuJWO87Lg",
// "R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK",
// YiYanChatModel.ERNIE4_3_5_8K,
// 86400
// );
// YiYanChatOptions yiYanOptions = new YiYanChatOptions();
// yiYanOptions.setMaxOutputTokens(2048);
// yiYanOptions.setTopP(0.6f);
// yiYanOptions.setTemperature(0.85f);
// yiYanChatClient = new YiYanChatClient(
// yiYanApi,
// yiYanOptions
// );
// }
//
// @Test
// public void callTest() {
//
// // tip: 百度的message 有特殊规则(最后一个message为当前请求的信息前面的message为历史对话信息)
// // tip: 地址 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/jlil56u11
// List<Message> messages = new ArrayList<>();
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
// messages.add(new UserMessage("长沙怎么样?"));
//
// ChatResponse call = yiYanChatClient.call(new Prompt(messages));
// System.err.println(call.getResult());
// }
//
// @Test
// public void streamTest() {
// List<Message> messages = new ArrayList<>();
// messages.add(new SystemMessage("你是一个优质的文言文作者,用文言文描述着各城市的人文风景,所有问题都采用文言文回答。"));
// messages.add(new UserMessage("长沙怎么样?"));
//
// Flux<ChatResponse> fluxResponse = yiYanChatClient.stream(new Prompt(messages));
// fluxResponse.subscribe(chatResponse -> System.err.print(chatResponse.getResult().getOutput().getContent()));
// // 阻止退出
// Scanner scanner = new Scanner(System.in);
// scanner.nextLine();
// }
}

View File

@ -4,7 +4,7 @@ import org.springframework.ai.image.ImagePrompt;
import org.springframework.ai.image.ImageResponse;
import org.junit.Before;
import org.junit.Test;
import org.springframework.ai.openai.OpenAiImageClient;
import org.springframework.ai.openai.OpenAiImageModel;
import org.springframework.ai.openai.api.OpenAiImageApi;
import javax.imageio.ImageIO;
@ -23,12 +23,12 @@ import java.util.Scanner;
public class OpenAiImageClientTests {
private OpenAiImageClient openAiImageClient;
private OpenAiImageModel openAiImageClient;
@Before
public void setup() {
// 初始化 openAiImageClient
this.openAiImageClient = new OpenAiImageClient(
this.openAiImageClient = new OpenAiImageModel(
new OpenAiImageApi("")
// new OpenAiImageOptions().setResponseFormat(OpenAiImageOptions.ResponseFormatEnum.URL.getValue()) TODO 芋艿临时处理
);

View File

@ -17,7 +17,8 @@ public class SunoTests {
@Before
public void setup() {
String url = "https://suno-om0w1cy6e-status2xxs-projects.vercel.app";
String url = "https://suno-55ishh05u-status2xxs-projects.vercel.app";
// String url = "http://127.0.0.1:3001";
this.sunoApi = new SunoApi(url);
}
@ -53,5 +54,4 @@ public class SunoTests {
System.out.println(limitUsageData);
}
}

View File

@ -160,19 +160,16 @@ spring:
gemini:
project-id: 1 # TODO 芋艿:缺配置
location: 2
qianfan: # 文心一言
api-key: x0cuLZ7XsaTCU08vuJWO87Lg
secret-key: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
cloud:
ai:
tongyi: # 通义千问
tongyi:
api-key: sk-Zsd81gZYg7
yudao.ai:
yiyan:
enable: true
aiPlatform: YI_YAN # TODO @fan建议每个都独立配置属性类
max-tokens: 1500
temperature: 0.85
topP: 0.8
topK: 0
appKey: x0cuLZ7XsaTCU08vuJWO87Lg
secretKey: R9mYF9dl9KASgi5RUq0FQt3wRisSnOcK
refreshTokenSecondTime: 86400
model: ERNIE4_3_5_8K
xinghuo:
enable: true
aiPlatform: XING_HUO # TODO @fan建议每个都独立配置属性类
@ -192,6 +189,7 @@ yudao.ai:
topP: 0.8
topK: 0
api-key: sk-Zsd81gZYg7
model: QWEN_TURBO
midjourney:
enable: true
# base-url: https://api.holdai.top/mj-relax/mj
@ -200,7 +198,8 @@ yudao.ai:
notify-url: http://java.nat300.top/admin-api/ai/image/midjourney/notify
suno:
enable: true
base-url: https://suno-om0w1cy6e-status2xxs-projects.vercel.app
# base-url: https://suno-55ishh05u-status2xxs-projects.vercel.app
base-url: http://127.0.0.1:3001
--- #################### 芋道相关配置 ####################