1
1
// 记得在 GREETTING 消息里说明可以使用的功能
2
2
// TODO: 写好之后更新一下使用文档的 2.5
3
- import type { AIFunction } from '../types'
3
+ import { type AIFunction , ALLOWED_DISCRETE_METHODS } from '../types'
4
4
import { useAssistant } from '../lib/useAssistant'
5
5
import { useZustand } from '../lib/useZustand'
6
6
import { useState , useRef , useEffect } from 'react'
@@ -34,6 +34,7 @@ import { nav_to_plots_view } from '../lib/assistant/nav_to_plots_view'
34
34
import { nav_to_statistics_view } from '../lib/assistant/nav_to_statistics_view'
35
35
import { nav_to_tools_view } from '../lib/assistant/nav_to_tools_view'
36
36
import { create_new_var } from '../lib/assistant/create_new_var'
37
+ import { create_sub_var , clear_sub_var } from '../lib/assistant/create_sub_var'
37
38
38
39
const md = markdownit ( { html : true , breaks : true } )
39
40
const funcs : AIFunction [ ] = [
@@ -44,9 +45,11 @@ const funcs: AIFunction[] = [
44
45
nav_to_statistics_view ,
45
46
nav_to_tools_view ,
46
47
create_new_var ,
48
+ create_sub_var ,
49
+ clear_sub_var ,
47
50
]
48
51
const GREETTING =
49
- '你好, 我是 PsychPen 的 AI 助手, 可以帮你讲解 PsychPen 的使用方法、探索你的数据集、导出数据、跳转页面、生成新变量等. 请问有什么可以帮你的?'
52
+ '你好, 我是 PsychPen 的 AI 助手, 可以帮你讲解 PsychPen 的使用方法、探索你的数据集、导出数据、跳转页面、生成/清除子变量 (标准化/中心化/离散化)、 生成新变量等. 请问有什么可以帮你的?'
50
53
51
54
export function AI ( ) {
52
55
const { ai, model } = useAssistant ( )
@@ -147,6 +150,35 @@ export function AI() {
147
150
148
151
try {
149
152
switch ( toolCall . function . name ) {
153
+ case 'clear_sub_var' : {
154
+ const { variable_names } = JSON . parse ( toolCall . function . arguments )
155
+ if (
156
+ ! Array . isArray ( variable_names ) ||
157
+ ! variable_names . every ( ( name ) => ( typeof name === 'string' ) && dataCols . some ( ( col ) => col . name === name ) )
158
+ ) {
159
+ throw new Error ( '变量名参数错误' )
160
+ }
161
+ newMessages [ 1 ] . content = `已请求清除变量 ${ ( variable_names as string [ ] ) . map ( ( name ) => `"${ name } "` ) . join ( '、' ) } 的所有子变量, 等待用户手动确认`
162
+ break
163
+ }
164
+ case 'create_sub_var' : {
165
+ const { variable_names, standardize, centralize, discretize } =
166
+ JSON . parse ( toolCall . function . arguments )
167
+ if (
168
+ ! Array . isArray ( variable_names ) ||
169
+ ! variable_names . every ( ( name ) => ( typeof name === 'string' ) && dataCols . some ( ( col ) => col . name === name ) )
170
+ ) {
171
+ throw new Error ( '变量名参数错误' )
172
+ }
173
+ newMessages [ 1 ] . content = `已请求生成变量 ${ ( variable_names as string [ ] ) . map ( ( name ) => `"${ name } "` ) . join ( '、' ) } 的${ [
174
+ standardize ? '标准化' : '' ,
175
+ centralize ? '中心化' : '' ,
176
+ discretize ? '离散化' : '' ,
177
+ ]
178
+ . filter ( ( part ) => part )
179
+ . join ( '、' ) } 子变量, 等待用户手动确认`
180
+ break
181
+ }
150
182
case 'create_new_var' : {
151
183
// eslint-disable-next-line @typescript-eslint/no-unused-vars
152
184
const { variable_name, calc_expression : _ } = JSON . parse (
@@ -397,12 +429,150 @@ function ToolCall({ toolCall }: { toolCall: ChatCompletionMessageToolCall }) {
397
429
const id = toolCall . id
398
430
const name = toolCall . function . name
399
431
const args = toolCall . function . arguments
400
- const { dataRows, _VariableView_addNewVar, messageApi } = useZustand ( )
432
+ const { dataRows, _VariableView_addNewVar, messageApi, dataCols , _VariableView_updateData } = useZustand ( )
401
433
const [ done , setDone ] = useState ( false )
402
434
const formerDone = sessionStorage . getItem ( id ) === 'done'
403
435
let element : React . ReactElement | null = null
404
436
let initDone = true
405
437
switch ( name ) {
438
+ case 'clear_sub_var' : {
439
+ const { variable_names } = JSON . parse ( args ) as { variable_names : string [ ] }
440
+ if ( ! formerDone ) {
441
+ initDone = false
442
+ }
443
+ element = (
444
+ < >
445
+ < div >
446
+ 执行函数{ ' ' }
447
+ < Tag color = 'blue' style = { { margin : 0 } } >
448
+ { funcs . find (
449
+ ( func ) => func . tool . function . name === toolCall . function . name ,
450
+ ) ?. label || `未知函数 (${ toolCall . function . name } )` }
451
+ </ Tag >
452
+ { done ? ', 已' : ', 是否确认' } 清除变量
453
+ { variable_names . map ( ( name ) => (
454
+ < Tag key = { name } style = { { margin : 0 , marginLeft : '0.3rem' } } color = 'blue' >
455
+ { name }
456
+ </ Tag >
457
+ ) ) } { ' ' }
458
+ 的所有子变量
459
+ </ div >
460
+ < div >
461
+ < Button
462
+ block
463
+ disabled = { done }
464
+ onClick = { ( ) => {
465
+ _VariableView_updateData (
466
+ dataCols . map ( ( col ) => {
467
+ if ( variable_names . includes ( col . name ) ) {
468
+ return {
469
+ ...col ,
470
+ subVars : undefined ,
471
+ }
472
+ }
473
+ return col
474
+ } ) . filter ( ( col ) => col . derived !== true ) ,
475
+ )
476
+ setDone ( true )
477
+ sessionStorage . setItem ( id , 'done' )
478
+ messageApi ?. success ( `已成功清除变量 ${ variable_names . map ( ( name ) => `"${ name } "` ) . join ( '、' ) } 的所有子变量` )
479
+ } }
480
+ >
481
+ { done ? '已清除子变量' : '确认清除子变量' }
482
+ </ Button >
483
+ </ div >
484
+ </ >
485
+ )
486
+ break
487
+ }
488
+ case 'create_sub_var' : {
489
+ const { variable_names, standardize, centralize, discretize } = JSON . parse ( args ) as {
490
+ variable_names : string [ ]
491
+ standardize : boolean | undefined
492
+ centralize : boolean | undefined
493
+ discretize : {
494
+ method : ALLOWED_DISCRETE_METHODS
495
+ groups : number
496
+ } | undefined
497
+ }
498
+ if ( ! formerDone ) {
499
+ initDone = false
500
+ }
501
+ const ALLOWED_METHOD = Object . values ( ALLOWED_DISCRETE_METHODS )
502
+ const shouldDiscritize = Boolean (
503
+ typeof discretize === 'object' &&
504
+ discretize . method &&
505
+ discretize . groups &&
506
+ ALLOWED_METHOD . includes ( discretize . method )
507
+ )
508
+ element = (
509
+ < >
510
+ < div >
511
+ 执行函数{ ' ' }
512
+ < Tag color = 'blue' style = { { margin : 0 } } >
513
+ { funcs . find (
514
+ ( func ) => func . tool . function . name === toolCall . function . name ,
515
+ ) ?. label || `未知函数 (${ toolCall . function . name } )` }
516
+ </ Tag >
517
+ { done ? ', 已' : ', 是否确认' } 生成变量
518
+ { variable_names . map ( ( name ) => (
519
+ < Tag key = { name } style = { { margin : 0 , marginLeft : '0.3rem' } } color = 'blue' >
520
+ { name }
521
+ </ Tag >
522
+ ) ) } { ' ' }
523
+ 的
524
+ { [
525
+ standardize ? '标准化' : '' ,
526
+ centralize ? '中心化' : '' ,
527
+ shouldDiscritize ? `离散化 (${ discretize ! . method } , ${ discretize ! . groups } 组) ` : '' ,
528
+ ]
529
+ . filter ( ( part ) => part )
530
+ . join ( '、' ) }
531
+ 子变量
532
+ </ div >
533
+ < div >
534
+ < Button
535
+ block
536
+ disabled = { done }
537
+ onClick = { ( ) => {
538
+ _VariableView_updateData (
539
+ dataCols . map ( ( col ) => {
540
+ if ( variable_names . includes ( col . name ) ) {
541
+ return {
542
+ ...col ,
543
+ subVars : {
544
+ standard : Boolean ( standardize ) || col . subVars ?. standard ,
545
+ center : Boolean ( centralize ) || col . subVars ?. center ,
546
+ discrete : shouldDiscritize
547
+ ? {
548
+ method : discretize ! . method ,
549
+ groups : discretize ! . groups ,
550
+ }
551
+ : col . subVars ?. discrete ,
552
+ } ,
553
+ }
554
+ }
555
+ return col
556
+ } ) . filter ( ( col ) => col . derived !== true ) ,
557
+ )
558
+ setDone ( true )
559
+ sessionStorage . setItem ( id , 'done' )
560
+ messageApi ?. success ( `已成功生成变量 ${ variable_names . map ( ( name ) => `"${ name } "` ) . join ( '、' ) } 的${ [
561
+ standardize ? '标准化' : '' ,
562
+ centralize ? '中心化' : '' ,
563
+ shouldDiscritize ? '离散化' : '' ,
564
+ ]
565
+ . filter ( ( part ) => part )
566
+ . join ( '、' ) } 子变量`)
567
+ } }
568
+ >
569
+ { done ? '已生成子变量' : '确认生成子变量' }
570
+ </ Button >
571
+ </ div >
572
+ </ >
573
+ )
574
+ break
575
+ }
406
576
case 'create_new_var' : {
407
577
const { variable_name, calc_expression } = JSON . parse ( args )
408
578
if ( ! formerDone ) {
0 commit comments