@ -0,0 +1,730 @@
import java.io.File ;
import java.io.IOException ;
import java.lang.annotation.ElementType ;
import java.lang.annotation.Retention ;
import java.lang.annotation.RetentionPolicy ;
import java.lang.annotation.Target ;
import java.lang.reflect.Field ;
import java.net.URI ;
import java.net.http.HttpClient ;
import java.net.http.HttpRequest ;
import java.net.http.HttpResponse ;
import java.nio.charset.StandardCharsets ;
import java.nio.file.Files ;
import java.nio.file.Path ;
import java.nio.file.StandardOpenOption ;
import java.time.Duration ;
import java.time.LocalDate ;
import java.time.LocalDateTime ;
import java.time.format.DateTimeFormatter ;
import java.util.* ;
import java.util.concurrent.* ;
import java.util.function.BiConsumer ;
import java.util.function.Function ;
import java.util.regex.Matcher ;
import java.util.regex.Pattern ;
import java.util.stream.Collectors ;
/ * *
* < p >
* JDK版本必须大于或等于21 , 直接运行将生成一份bat脚本或shell脚本 , 下载JDK可以在浏览器打开链接按需下载 :
* https : / / www . azul . com / downloads / ? version = java - 21 - lts & package = jdk # zulu
* < / p >
* /
public class LLMBenchmarkTester {
public static final String SEP = " ============================================================================================= " ;
public static final Field [ ] PARAM_FIELD = ScriptParameter . class . getDeclaredFields ( ) ;
public static final Pattern CONTENT_PATTERN = Pattern . compile ( " \" content \" \\ s*: \\ s* \" ([^ \" ]*) \" " ) ;
public static void main ( String [ ] args ) throws Exception {
if ( args = = null | | args . length = = 0 ) {
createRunScript ( ) ;
} else {
ScriptParameter param = readScriptParameter ( args ) ;
printScriptParam ( param ) ;
List < ExecuteContext > executeContexts = new ArrayList < > ( ) ;
if ( param . isTestChatModel ( ) ) {
TextQuestion textQuestion = new TextQuestion ( param ) ;
List < ExecuteContext > tasks = submit (
param . modelName . split ( " , " ) ,
param . threadSize . split ( " , " ) ,
param ,
textQuestion ,
null
) ;
if ( ! tasks . isEmpty ( ) ) {
executeContexts . addAll ( tasks ) ;
}
}
if ( param . isTestVlModel ( ) ) {
ImageQuestion imageQuestion = new ImageQuestion ( param ) ;
List < ExecuteContext > tasks = submit (
param . vlModelName . split ( " , " ) ,
param . threadSize . split ( " , " ) ,
param ,
null ,
imageQuestion
) ;
if ( ! tasks . isEmpty ( ) ) {
executeContexts . addAll ( tasks ) ;
}
}
if ( ! executeContexts . isEmpty ( ) ) {
String today = LocalDate . now ( ) . format ( DateTimeFormatter . ofPattern ( " yyyyMMdd " ) ) ;
String logName = String . format ( " llm_bench_%s.log " , today ) ;
Path logPath = Path . of ( System . getProperty ( " user.dir " ) , logName ) ;
for ( ExecuteContext executeContext : executeContexts ) {
executeContext . latch . await ( ) ;
writeLog ( logPath , executeContext ) ;
executeContext . executor . shutdownNow ( ) ;
executeContext . sessionMap . clear ( ) ;
}
}
}
}
private static void writeLog ( Path logPath , ExecuteContext executeContext ) throws IOException {
Collection < HttpContext > values = executeContext . sessionMap . values ( ) ;
double avgResponse = values . stream ( ) . mapToLong ( HttpContext : : toEndMillis ) . filter ( d - > d > 0L ) . average ( ) . orElse ( 0D ) ;
long totalTime = values . stream ( ) . mapToLong ( HttpContext : : toFinishMillis ) . sum ( ) ;
long successNum = values . stream ( ) . filter ( d - > d . success ) . count ( ) ;
long maxResponse = values . stream ( ) . mapToLong ( HttpContext : : toEndMillis ) . max ( ) . orElse ( 0L ) ;
int outTextLength = values . stream ( ) . mapToInt ( d - > d . outTexts . stream ( ) . mapToInt ( s - > s ! = null ? s . length ( ) : 0 ) . sum ( ) ) . sum ( ) ;
int outTextCount = values . stream ( ) . mapToInt ( d - > d . outTexts ! = null ? d . outTexts . size ( ) : 0 ) . sum ( ) ;
String format = " " "
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
模型 : % s
并发量 : % d
问题数量 : % d
成功 : % d
首次响应最长耗时 : % d毫秒
首次响应平均耗时 : % f毫秒
一共输出 : % d字 , 共输出 % d次 , 共计耗时 : % d毫秒
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
% s
" " " ;
String msg = String . format (
format ,
executeContext . model ,
executeContext . threadSize ,
executeContext . sessionMap . size ( ) ,
successNum ,
maxResponse ,
avgResponse ,
outTextLength ,
outTextCount ,
totalTime ,
System . lineSeparator ( )
) ;
Files . writeString ( logPath , msg , StandardCharsets . UTF_8 , StandardOpenOption . CREATE , StandardOpenOption . APPEND ) ;
}
private static List < ExecuteContext > submit ( String [ ] models , String [ ] threadSizeStr , ScriptParameter param , TextQuestion textQuestion , ImageQuestion imageQuestion ) {
List < Integer > threadSizeList = Arrays . stream ( threadSizeStr ) . map ( s - > Integer . parseInt ( s . strip ( ) ) ) . toList ( ) ;
List < ExecuteContext > executeContexts = new ArrayList < > ( ) ;
for ( Integer threadSize : threadSizeList ) {
for ( String model : models ) {
if ( textQuestion ! = null ) {
executeContexts . add ( execute ( threadSize , model , param , textQuestion . getRequestParams ( model ) ) ) ;
} else if ( imageQuestion ! = null ) {
executeContexts . add ( execute ( threadSize , model , param , imageQuestion . getRequestParams ( model ) ) ) ;
}
}
}
return executeContexts ;
}
private static ExecuteContext execute ( int threadSize , String model , ScriptParameter param , List < String > requestParams ) {
CountDownLatch latch = new CountDownLatch ( requestParams . size ( ) ) ;
ConcurrentHashMap < Long , HttpContext > sessionMap = new ConcurrentHashMap < > ( requestParams . size ( ) ) ;
ExecutorService executorService = Executors . newFixedThreadPool ( threadSize ) ;
URI uri = URI . create ( param . openAiApiHost ) ;
executorService . execute ( ( ) - > {
for ( String requestBody : requestParams ) {
startHttp ( uri , param . apiKey , requestBody , latch , sessionMap ) ;
}
} ) ;
return new ExecuteContext ( model , threadSize , sessionMap , latch , executorService ) ;
}
record ExecuteContext ( String model ,
int threadSize ,
ConcurrentHashMap < Long , HttpContext > sessionMap ,
CountDownLatch latch ,
ExecutorService executor ) {
}
private static void startHttp ( URI uri , String apiKey , String requestBody , CountDownLatch latch , Map < Long , HttpContext > sessionMap ) {
HttpRequest httpRequest = HttpRequest . newBuilder ( )
. uri ( uri )
. header ( " Content-Type " , " application/json " )
. header ( " Authorization " , " Bearer " + apiKey )
. timeout ( Duration . ofSeconds ( 15 ) )
. POST ( HttpRequest . BodyPublishers . ofString ( requestBody ) )
. build ( ) ;
try ( HttpClient client = HttpClient . newHttpClient ( ) ) {
HttpContext context = new HttpContext ( ) ;
context . start = LocalDateTime . now ( ) ;
context . outTexts = new ArrayList < > ( ) ;
Flow . Subscriber < String > subscriber = createResponseFluxHandler ( context ) ;
CompletableFuture < HttpResponse < Void > > future =
client . sendAsync ( httpRequest , HttpResponse . BodyHandlers . fromLineSubscriber ( subscriber ) ) ;
handleHttpResponseFuture ( future , latch , sessionMap , context ) ;
}
}
private static class HttpContext {
long sessionId = System . nanoTime ( ) ;
LocalDateTime start ;
LocalDateTime end ;
LocalDateTime completed ;
boolean success ;
List < String > outTexts ;
public long toEndMillis ( ) {
return this . end ! = null ? Duration . between ( this . start , this . end ) . toMillis ( ) : 0L ;
}
public long toFinishMillis ( ) {
return this . completed ! = null ? Duration . between ( this . start , this . completed ) . toMillis ( ) : 0L ;
}
}
private static void handleHttpResponseFuture ( CompletableFuture < HttpResponse < Void > > future ,
CountDownLatch latch ,
Map < Long , HttpContext > sessionMap ,
HttpContext context ) {
future . whenComplete ( ( response , exception ) - > context . success = false )
. thenAccept ( response - > {
context . success = response . statusCode ( ) = = 200 ;
sessionMap . putIfAbsent ( context . sessionId , context ) ;
latch . countDown ( ) ;
} ) . exceptionally ( err - > {
context . success = false ;
sessionMap . putIfAbsent ( context . sessionId , context ) ;
latch . countDown ( ) ;
return null ;
} ) ;
}
private static Flow . Subscriber < String > createResponseFluxHandler ( HttpContext context ) {
return new Flow . Subscriber < > ( ) {
@Override
public void onSubscribe ( Flow . Subscription subscription ) {
context . end = LocalDateTime . now ( ) ;
subscription . request ( Long . MAX_VALUE ) ;
}
@Override
public void onNext ( String item ) {
if ( item ! = null & & ! item . isEmpty ( ) ) {
Matcher matcher = CONTENT_PATTERN . matcher ( item ) ;
String group ;
if ( matcher . find ( ) & & null ! = ( group = matcher . group ( 1 ) ) & & ! group . isEmpty ( ) ) {
context . outTexts . add ( group ) ;
}
}
}
@Override
public void onError ( Throwable throwable ) {
context . success = false ;
}
@Override
public void onComplete ( ) {
context . completed = LocalDateTime . now ( ) ;
/ / System . out . println ( context . outTexts ) ;
}
} ;
}
private static class ImageQuestion {
private static final Map < String , List < String > > cache = new ConcurrentHashMap < > ( ) ;
private static final String template = " " "
{
" model " : " ${model} " ,
" messages " : [
{
" role " : " user " ,
" content " : [
{
" type " : " text " ,
" text " : " 这张图片里有什么? "
} ,
{
" type " : " image_url " ,
" image_url " : {
" url " : " ${imageBase64} "
}
}
]
}
] ,
" stream " : true
}
" " " .strip();
List < String > list ;
public String getImgHead ( File file ) {
if ( file . getName ( ) . endsWith ( " png " ) ) {
return " image/png " ;
}
if ( file . getName ( ) . endsWith ( " jpg " ) | | file . getName ( ) . endsWith ( " jpeg " ) ) {
return " image/jpeg " ;
}
return null ;
}
public String tryEncodeBase64 ( File file , Path path ) {
String imgHead = getImgHead ( file ) ;
if ( imgHead = = null | | imgHead . isBlank ( ) ) {
return null ;
}
try {
return " data: " + imgHead + " ;base64, " + Base64 . getEncoder ( ) . encodeToString ( Files . readAllBytes ( path ) ) ;
} catch ( IOException e ) {
throw new RuntimeException ( e ) ;
}
}
public List < String > base64Image ( ScriptParameter parameter ) throws IOException {
try ( var files = Files . list ( Path . of ( parameter . vlImgFolder ) ) ) {
return files . map ( path - > {
File file = path . toFile ( ) ;
return tryEncodeBase64 ( file , path ) ;
} ) . filter ( Objects : : nonNull ) . toList ( ) ;
}
}
public ImageQuestion ( ScriptParameter parameter ) throws IOException {
List < String > base64List = base64Image ( parameter ) ;
int imgNum = Integer . parseInt ( parameter . imgSize ) ;
this . list = new ArrayList < > ( imgNum ) ;
do {
for ( String item : base64List ) {
this . list . add ( item ) ;
if ( this . list . size ( ) = = imgNum ) {
break ;
}
}
} while ( this . list . size ( ) < imgNum ) ;
}
public List < String > getRequestParams ( String model ) {
List < String > cacheParams = cache . get ( model ) ;
if ( cacheParams ! = null & & ! cacheParams . isEmpty ( ) ) {
return cacheParams ;
}
List < String > params = this . list . stream ( ) . map ( s - > this . toJsonParam ( model , s ) ) . toList ( ) ;
cache . put ( model , params ) ;
return params ;
}
public String toJsonParam ( String model , String imageBase64 ) {
return template . replace ( " ${model} " , model ) . replace ( " ${imageBase64} " , imageBase64 ) ;
}
}
private static class TextQuestion {
private static final Map < String , List < String > > cache = new ConcurrentHashMap < > ( ) ;
private static final String template = " " "
{
" model " : " ${model} " ,
" messages " : [
{
" role " : " user " ,
" content " : " ${prompt} "
}
] ,
" stream " : true
}
" " " .strip();
/ / 解析文件得到的问题列表
List < String > list ;
public TextQuestion ( ScriptParameter parameter ) throws IOException {
this . list = Files . readAllLines ( Path . of ( parameter . chatDatasetsPath ) ) ;
}
public List < String > getRequestParams ( String model ) {
List < String > cacheParams = cache . get ( model ) ;
if ( cacheParams ! = null & & cacheParams . isEmpty ( ) ) {
return cacheParams ;
}
List < String > params = this . list . stream ( ) . map ( s - > this . toJsonParam ( model , s ) ) . toList ( ) ;
cache . put ( model , params ) ;
return params ;
}
public String toJsonParam ( String model , String prompt ) {
return template . replace ( " ${model} " , model ) . replace ( " ${prompt} " , prompt ) ;
}
}
private static void printScriptParam ( ScriptParameter param ) throws IllegalAccessException {
System . out . println ( " 本次执行脚本的参数如下: " ) ;
for ( Field field : PARAM_FIELD ) {
if ( field . isAnnotationPresent ( EnvName . class ) ) {
System . out . println ( SEP ) ;
EnvName annotation = field . getAnnotation ( EnvName . class ) ;
field . setAccessible ( true ) ;
Object value = field . get ( param ) ;
System . out . printf ( " 参数: %s 数值: %s%n " , annotation . value ( ) , value ) ;
}
}
System . out . println ( SEP ) ;
}
private static void createRunScript ( ) throws IOException {
String osName = System . getProperty ( " os.name " ) . toLowerCase ( ) ;
if ( osName . contains ( " win " ) ) {
generateWindowsBat ( ) ;
} else {
generateShellScript ( ) ;
}
}
private static File createScripeFile ( String extName ) {
String date = LocalDateTime . now ( ) . format ( DateTimeFormatter . ofPattern ( " yyyyMMdd " ) ) ;
File file = new File ( String . format ( " llm_benchmark_tester_%s.%s " , date , extName ) ) ;
if ( file . exists ( ) ) {
throw new RuntimeException ( String . format ( " 您可以通过 %s 脚本直接运行 " , file . getAbsolutePath ( ) ) ) ;
}
boolean createBat ;
try {
createBat = file . createNewFile ( ) ;
} catch ( Exception e ) {
throw new RuntimeException ( String . format ( " 创建 %s 脚本文件异常: %s " , file . getAbsolutePath ( ) , e . getMessage ( ) ) , e ) ;
}
if ( ! createBat ) {
throw new RuntimeException ( String . format ( " 创建 %s 脚本文件失败 " , file . getAbsolutePath ( ) ) ) ;
}
System . out . println ( " 已为你生成一份脚本, 请修改脚本中的环境变量, 使用脚本运行 " ) ;
System . out . printf ( " 脚本的存储路径 %s%n " , file . getAbsolutePath ( ) ) ;
System . out . println ( " 运行脚本之前, 请确保脚本文件的换行符与系统相匹配, 否则会无法运行 " ) ;
return file ;
}
private static void writeScriptFile ( File file , String template , BiConsumer < EnvName , List < String > > eachFunc ) throws IOException {
List < String > envLines = new ArrayList < > ( ) ;
for ( Field field : PARAM_FIELD ) {
if ( field . isAnnotationPresent ( EnvName . class ) ) {
EnvName annotation = field . getAnnotation ( EnvName . class ) ;
eachFunc . accept ( annotation , envLines ) ;
}
}
if ( ! envLines . isEmpty ( ) ) {
String envList = envLines . stream ( ) . collect ( Collectors . joining ( System . lineSeparator ( ) ) ) ;
String script = template . replace ( " ${ENV_LINES} " , envList ) ;
Files . writeString ( file . toPath ( ) , script . strip ( ) , StandardCharsets . UTF_8 ) ;
}
}
private static void generateWindowsBat ( ) throws IOException {
File file = createScripeFile ( " bat " ) ;
String batTemplate = " " "
@echo off
: : java可执行文件路径 , 不是JAVA_HOME , 是完整的java可执行文件路径 , 例如 : D : \ \ jdk - 2108 \ \ bin \ \ java
set JAVA_BIN =
: : 脚本存放路径 , 例如 : E : \ \ JExample \ \ src \ \ LLMBenchmarkTester . java
set SCRIPT_PATH =
$ { ENV_LINES }
: : 基于环境变量的方式执行 , 交互式命令行执行用这个命令 : % JAVA_BIN % % SCRIPT_PATH % - p input
% JAVA_BIN % % SCRIPT_PATH % - p env
pause
" " " ;
writeScriptFile ( file , batTemplate , ( envName , envLines ) - > {
envLines . add ( " :: " + envName . desc ( ) ) ;
envLines . add ( " set " + envName . value ( ) + " = " ) ;
} ) ;
}
private static void generateShellScript ( ) throws IOException {
File file = createScripeFile ( " sh " ) ;
String bashTemplate = " " "
# ! / bin / bash
# java可执行文件路径 , 不是JAVA_HOME , 是完整的java可执行文件路径 , 例如 : / opt / jdk - 2108 / bin / java
JAVA_BIN = " "
# 脚本存放路径 , 例如 : / home / user / JExample / src / LLMBenchmarkTester . java
SCRIPT_PATH = " "
$ { ENV_LINES }
# 基于环境变量的方式执行 , 交互式命令行执行用这个命令 : $JAVA_BIN $SCRIPT_PATH - p input
" $JAVA_BIN " " $SCRIPT_PATH " - p env
" " " ;
writeScriptFile ( file , bashTemplate , ( envName , envLines ) - > {
envLines . add ( " # " + envName . desc ( ) ) ;
envLines . add ( envName . value ( ) + " = " ) ;
} ) ;
}
private static ScriptParameter readScriptParameter ( String [ ] args ) throws IllegalAccessException {
if ( args ! = null & & args . length > 0 ) {
boolean p = Arrays . stream ( args ) . anyMatch ( s - > s . equalsIgnoreCase ( " -p " ) ) ;
if ( p & & Arrays . stream ( args ) . anyMatch ( s - > s . equalsIgnoreCase ( " env " ) ) ) {
return initScriptParamFromEnv ( ) ;
}
if ( p & & Arrays . stream ( args ) . anyMatch ( s - > s . equalsIgnoreCase ( " input " ) ) ) {
return initScriptParamFromAsk ( ) ;
}
}
throw new RuntimeException ( " 命令错误, 请检查参数是否正确 " ) ;
}
private static ScriptParameter initScriptParamFromEnv ( ) throws IllegalAccessException {
ScriptParameter param = new ScriptParameter ( ) ;
param . channel = 1 ;
for ( Field field : PARAM_FIELD ) {
if ( field . isAnnotationPresent ( EnvName . class ) ) {
EnvName envName = field . getAnnotation ( EnvName . class ) ;
String fieldValue = System . getenv ( envName . value ( ) ) ;
String formatValue = formatValue ( fieldValue , field ) ;
if ( field . isAnnotationPresent ( NotBlank . class ) & & ( formatValue = = null | | formatValue . isBlank ( ) ) ) {
throw new RuntimeException ( String . format ( " 环境变量[%s]不能为空或空白字符 " , envName . value ( ) ) ) ;
}
if ( ! isValidValue ( formatValue , field ) ) {
throw new RuntimeException ( String . format ( " 环境变量[%s]数值不合法, 当前值:[%s] " , envName . value ( ) , formatValue ) ) ;
}
field . setAccessible ( true ) ;
field . set ( param , fieldValue ) ;
}
}
return param ;
}
private static ScriptParameter initScriptParamFromAsk ( ) throws IllegalAccessException {
ScriptParameter param = new ScriptParameter ( ) ;
param . channel = 2 ;
Scanner scanner = new Scanner ( System . in ) ;
for ( Field field : PARAM_FIELD ) {
if ( field . isAnnotationPresent ( AskUser . class ) ) {
AskUser askUser = field . getAnnotation ( AskUser . class ) ;
System . out . println ( askUser . value ( ) + " : " ) ;
boolean isNotBlank = field . isAnnotationPresent ( NotBlank . class ) ;
for ( ; ; ) {
String userInput = scanner . nextLine ( ) . trim ( ) ;
String formatValue = formatValue ( userInput , field ) ;
/ / 允许为空并且输入值为空
if ( ! isNotBlank & & ( formatValue = = null | | formatValue . isBlank ( ) ) ) {
break ;
}
/ / 非空并且输入值合法
if ( isNotBlank & & formatValue ! = null & & ! formatValue . isBlank ( ) & & isValidValue ( formatValue , field ) ) {
field . setAccessible ( true ) ;
field . set ( param , formatValue ) ;
break ;
}
System . out . print ( " 请重新输入: " ) ;
}
}
}
return param ;
}
/ / 顺序校验
private static Boolean isValidValue ( String formatValue , Field field ) {
if ( field . isAnnotationPresent ( Validator . class ) ) {
Validator anno = field . getAnnotation ( Validator . class ) ;
return Arrays . stream ( anno . value ( ) )
. map ( validator - > Constants . TEXT_VALIDATOR . get ( validator . name ( ) ) . apply ( formatValue ) )
. allMatch ( Boolean . TRUE : : equals ) ;
}
return Boolean . TRUE ;
}
/ / 顺序格式化
private static String formatValue ( String fieldValue , Field field ) {
if ( field . isAnnotationPresent ( Formatter . class ) ) {
Formatter anno = field . getAnnotation ( Formatter . class ) ;
for ( TextFormater fmt : anno . value ( ) ) {
fieldValue = Constants . TEXT_FORMATER . get ( fmt . name ( ) ) . apply ( fieldValue ) ;
}
}
return fieldValue ;
}
@Retention ( RetentionPolicy . RUNTIME )
@Target ( { ElementType . FIELD } )
public @interface AskUser {
String value ( ) ;
}
@Retention ( RetentionPolicy . RUNTIME )
@Target ( { ElementType . FIELD } )
public @interface EnvName {
String value ( ) default " " ;
String desc ( ) ;
}
@Retention ( RetentionPolicy . RUNTIME )
@Target ( { ElementType . FIELD } )
public @interface NotBlank {
}
@Retention ( RetentionPolicy . RUNTIME )
@Target ( { ElementType . FIELD } )
public @interface Validator {
TextValidator [ ] value ( ) ;
}
@Retention ( RetentionPolicy . RUNTIME )
@Target ( { ElementType . FIELD } )
public @interface Formatter {
TextFormater [ ] value ( ) ;
}
public enum TextValidator {
MUST_URL , MUST_FOLDER , MUST_TXT , MUST_NUM ;
}
public enum TextFormater {
STRIP , COMMA_CN_2_EN ;
}
private static class Constants {
/ / 去除字符串两端空白字符和制表符
public static final Function < String , String > STRIP_FORMATTER =
str - > Optional . ofNullable ( str ) . map ( java . lang . String : : strip ) . orElse ( " " ) ;
/ / 中文逗号替换成英文逗号
public static final Function < String , String > COMMA_CN_2_EN =
str - > Optional . ofNullable ( str ) . map ( d - > d . replaceAll ( " , " , " , " ) ) . orElse ( " " ) ;
/ / 字符串必须是一个http链接
public static final Function < String , Boolean > URL_VALIDATOR =
str - > str ! = null & & ( str . startsWith ( " http:// " ) | | str . startsWith ( " https:// " ) ) ;
/ / 字符串必须是一个合法的文件路径且已存在的文件夹
public static final Function < String , Boolean > FOLDER_VALIDATOR = str - > {
if ( str ! = null & & ! str . isBlank ( ) ) {
try {
File file = new File ( str ) ;
return file . exists ( ) & & file . isDirectory ( ) ;
} catch ( Exception e ) {
return false ;
}
}
return true ;
} ;
/ / 字符串必须是一个合法的文件路径且已存在的txt文件
public static final Function < String , Boolean > TXT_FILE_VALIDATOR = str - > {
if ( str ! = null & & ! str . isBlank ( ) ) {
try {
File file = new File ( str ) ;
return file . exists ( ) & & file . isFile ( ) & & file . getName ( ) . endsWith ( " .txt " ) ;
} catch ( Exception e ) {
return false ;
}
}
return true ;
} ;
/ / 字符串必须是一个整数
public static final Function < String , Boolean > NUMBER_VALIDATOR = str - > {
if ( str ! = null & & ! str . isBlank ( ) ) {
try {
Integer . parseInt ( str ) ;
} catch ( Exception e ) {
return false ;
}
}
return true ;
} ;
/ / 文本格式化工具注册表
public static final Map < String , Function < String , String > > TEXT_FORMATER =
Map . of (
TextFormater . STRIP . name ( ) , Constants . STRIP_FORMATTER ,
TextFormater . COMMA_CN_2_EN . name ( ) , Constants . COMMA_CN_2_EN
) ;
/ / 文本验证工具注册表
public static final Map < String , Function < String , Boolean > > TEXT_VALIDATOR =
Map . of (
TextValidator . MUST_URL . name ( ) , URL_VALIDATOR ,
TextValidator . MUST_FOLDER . name ( ) , FOLDER_VALIDATOR ,
TextValidator . MUST_TXT . name ( ) , TXT_FILE_VALIDATOR ,
TextValidator . MUST_NUM . name ( ) , NUMBER_VALIDATOR
) ;
}
public static class ScriptParameter {
/ / 1 = 环境变量 , 2 = 交互式命令行
int channel ;
@NotBlank
@EnvName ( value = " BENCH_LLM_API_HOST " , desc = " OpenAI API 的访问地址, 例如: http://localhost:8080/v1/chat/completions " )
@Validator ( value = TextValidator . MUST_URL )
@Formatter ( value = TextFormater . STRIP )
@AskUser ( value = " 请输入 OpenAI API 的访问地址 (例如: http://localhost:8080/v1/chat/completions) " )
String openAiApiHost ;
@NotBlank
@EnvName ( value = " BENCH_LLM_API_KEY " , desc = " ApiKey或者叫API令牌 " )
@Formatter ( value = TextFormater . STRIP )
@AskUser ( value = " 请输入ApiKey或者叫API令牌 " )
String apiKey ;
@NotBlank
@EnvName ( value = " BENCH_THREAD_SIZE_ARRAY " , desc = " 请输入线程池配置, 示例值: 10,50,100 " )
@Formatter ( value = { TextFormater . STRIP , TextFormater . COMMA_CN_2_EN } )
@AskUser ( value = " 请输入线程池配置 (示例值: 10,50,100) " )
String threadSize ;
@EnvName ( value = " BENCH_LLM_MODEL_NAME " , desc = " 文本模型名称, 多个使用英文逗号隔开, 如果不测试文生文模型可以不设置, 示例值: qwen2.5,qwen3 " )
@Formatter ( value = { TextFormater . STRIP , TextFormater . COMMA_CN_2_EN } )
@AskUser ( value = " 请输入文本模型名称, 多个使用英文逗号隔开, 如果不测试文生文模型可以直接回车 (示例值: qwen2.5,qwen3) " )
String modelName ;
@EnvName ( value = " BENCH_LLM_VL_MODEL_NAME " , desc = " VL模型名称, 多个用英文逗号隔开, 如果不测试VL模型可以不设置 " )
@Formatter ( value = { TextFormater . STRIP , TextFormater . COMMA_CN_2_EN } )
@AskUser ( value = " 请输入VL模型名称, 多个用英文逗号隔开, 如果不测试VL模型可以直接回车 " )
String vlModelName ;
@EnvName ( value = " BENCH_LLM_VL_IMG_FOLDER " , desc = " 调用VL模型的图片存储目录, 如果不测试VL模型可以不设置, 示例值: /home/image " )
@Validator ( value = TextValidator . MUST_FOLDER )
@Formatter ( value = { TextFormater . STRIP } )
@AskUser ( value = " 请输入调用VL模型的图片存储目录, 如果不测试VL模型可以直接回车 (示例值: /home/image) " )
String vlImgFolder ;
@EnvName ( value = " BENCH_LLM_CHAT_MODEL_DATASETS " , desc = " 文生文测试数据集的文件路径, 如果不测试文生文模型可以不设置, 必须是一个.txt文件 (示例值: /home/datasets.txt) " )
@Validator ( value = TextValidator . MUST_TXT )
@AskUser ( value = " 请输入文生文测试数据集的文件路径, 必须是一个.txt文件, 如果不测试文生文模型可以直接回车 (示例值: /home/datasets.txt) " )
String chatDatasetsPath ;
@EnvName ( value = " BENCH_LLM_VL_IMG_SIZE " , desc = " 调用VL模型的测试图片数量, 如果文件夹下的图片数量不够, 会复制直到到足够数量, 如果不测试VL模型可以不设置 (示例值: 300) " )
@Validator ( value = TextValidator . MUST_NUM )
@Formatter ( value = TextFormater . STRIP )
@AskUser ( " 请输入调用VL模型的测试图片数量, 如果文件夹下的图片数量不够, 会复制直到到足够数量, 如果不测试VL模型可以直接回车 (示例值: 300) " )
String imgSize ;
public boolean isTestChatModel ( ) {
return this . modelName ! = null & & ! this . modelName . isBlank ( )
& & this . chatDatasetsPath ! = null & & ! this . chatDatasetsPath . isBlank ( ) ;
}
public boolean isTestVlModel ( ) {
return this . vlModelName ! = null & & ! this . vlModelName . isBlank ( )
& & this . vlImgFolder ! = null & & ! this . vlImgFolder . isBlank ( )
& & this . imgSize ! = null & & ! this . imgSize . isBlank ( ) ;
}
}
}