Skip to content

Commit 99b0f06

Browse files
committed
feat: AI助手现在可以直接操作生成/清除中心化/标准化/离散化子变量
1 parent 9bd7e8b commit 99b0f06

File tree

5 files changed

+271
-9
lines changed

5 files changed

+271
-9
lines changed

src/components/AI.tsx

+173-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
// 记得在 GREETTING 消息里说明可以使用的功能
22
// TODO: 写好之后更新一下使用文档的 2.5
3-
import type { AIFunction } from '../types'
3+
import { type AIFunction, ALLOWED_DISCRETE_METHODS } from '../types'
44
import { useAssistant } from '../lib/useAssistant'
55
import { useZustand } from '../lib/useZustand'
66
import { useState, useRef, useEffect } from 'react'
@@ -34,6 +34,7 @@ import { nav_to_plots_view } from '../lib/assistant/nav_to_plots_view'
3434
import { nav_to_statistics_view } from '../lib/assistant/nav_to_statistics_view'
3535
import { nav_to_tools_view } from '../lib/assistant/nav_to_tools_view'
3636
import { create_new_var } from '../lib/assistant/create_new_var'
37+
import { create_sub_var, clear_sub_var } from '../lib/assistant/create_sub_var'
3738

3839
const md = markdownit({ html: true, breaks: true })
3940
const funcs: AIFunction[] = [
@@ -44,9 +45,11 @@ const funcs: AIFunction[] = [
4445
nav_to_statistics_view,
4546
nav_to_tools_view,
4647
create_new_var,
48+
create_sub_var,
49+
clear_sub_var,
4750
]
4851
const GREETTING =
49-
'你好, 我是 PsychPen 的 AI 助手, 可以帮你讲解 PsychPen 的使用方法、探索你的数据集、导出数据、跳转页面、生成新变量等. 请问有什么可以帮你的?'
52+
'你好, 我是 PsychPen 的 AI 助手, 可以帮你讲解 PsychPen 的使用方法、探索你的数据集、导出数据、跳转页面、生成/清除子变量 (标准化/中心化/离散化)、生成新变量等. 请问有什么可以帮你的?'
5053

5154
export function AI() {
5255
const { ai, model } = useAssistant()
@@ -147,6 +150,35 @@ export function AI() {
147150

148151
try {
149152
switch (toolCall.function.name) {
153+
case 'clear_sub_var': {
154+
const { variable_names } = JSON.parse(toolCall.function.arguments)
155+
if (
156+
!Array.isArray(variable_names) ||
157+
!variable_names.every((name) => (typeof name === 'string') && dataCols.some((col) => col.name === name))
158+
) {
159+
throw new Error('变量名参数错误')
160+
}
161+
newMessages[1].content = `已请求清除变量 ${(variable_names as string[]).map((name) => `"${name}"`).join('、')} 的所有子变量, 等待用户手动确认`
162+
break
163+
}
164+
case 'create_sub_var': {
165+
const { variable_names, standardize, centralize, discretize } =
166+
JSON.parse(toolCall.function.arguments)
167+
if (
168+
!Array.isArray(variable_names) ||
169+
!variable_names.every((name) => (typeof name === 'string') && dataCols.some((col) => col.name === name))
170+
) {
171+
throw new Error('变量名参数错误')
172+
}
173+
newMessages[1].content = `已请求生成变量 ${(variable_names as string[]).map((name) => `"${name}"`).join('、')}${[
174+
standardize ? '标准化' : '',
175+
centralize ? '中心化' : '',
176+
discretize ? '离散化' : '',
177+
]
178+
.filter((part) => part)
179+
.join('、')}子变量, 等待用户手动确认`
180+
break
181+
}
150182
case 'create_new_var': {
151183
// eslint-disable-next-line @typescript-eslint/no-unused-vars
152184
const { variable_name, calc_expression: _ } = JSON.parse(
@@ -397,12 +429,150 @@ function ToolCall({ toolCall }: { toolCall: ChatCompletionMessageToolCall }) {
397429
const id = toolCall.id
398430
const name = toolCall.function.name
399431
const args = toolCall.function.arguments
400-
const { dataRows, _VariableView_addNewVar, messageApi } = useZustand()
432+
const { dataRows, _VariableView_addNewVar, messageApi, dataCols, _VariableView_updateData } = useZustand()
401433
const [done, setDone] = useState(false)
402434
const formerDone = sessionStorage.getItem(id) === 'done'
403435
let element: React.ReactElement | null = null
404436
let initDone = true
405437
switch (name) {
438+
case 'clear_sub_var': {
439+
const { variable_names } = JSON.parse(args) as { variable_names: string[] }
440+
if (!formerDone) {
441+
initDone = false
442+
}
443+
element = (
444+
<>
445+
<div>
446+
执行函数{' '}
447+
<Tag color='blue' style={{ margin: 0 }}>
448+
{funcs.find(
449+
(func) => func.tool.function.name === toolCall.function.name,
450+
)?.label || `未知函数 (${toolCall.function.name})`}
451+
</Tag>
452+
{done ? ', 已' : ', 是否确认'}清除变量
453+
{variable_names.map((name) => (
454+
<Tag key={name} style={{ margin: 0, marginLeft: '0.3rem' }} color='blue'>
455+
{name}
456+
</Tag>
457+
))}{' '}
458+
的所有子变量
459+
</div>
460+
<div>
461+
<Button
462+
block
463+
disabled={done}
464+
onClick={() => {
465+
_VariableView_updateData(
466+
dataCols.map((col) => {
467+
if (variable_names.includes(col.name)) {
468+
return {
469+
...col,
470+
subVars: undefined,
471+
}
472+
}
473+
return col
474+
}).filter((col) => col.derived !== true),
475+
)
476+
setDone(true)
477+
sessionStorage.setItem(id, 'done')
478+
messageApi?.success(`已成功清除变量 ${variable_names.map((name) => `"${name}"`).join('、')} 的所有子变量`)
479+
}}
480+
>
481+
{done ? '已清除子变量' : '确认清除子变量'}
482+
</Button>
483+
</div>
484+
</>
485+
)
486+
break
487+
}
488+
case 'create_sub_var': {
489+
const { variable_names, standardize, centralize, discretize } = JSON.parse(args) as {
490+
variable_names: string[]
491+
standardize: boolean | undefined
492+
centralize: boolean | undefined
493+
discretize: {
494+
method: ALLOWED_DISCRETE_METHODS
495+
groups: number
496+
} | undefined
497+
}
498+
if (!formerDone) {
499+
initDone = false
500+
}
501+
const ALLOWED_METHOD = Object.values(ALLOWED_DISCRETE_METHODS)
502+
const shouldDiscritize = Boolean(
503+
typeof discretize === 'object' &&
504+
discretize.method &&
505+
discretize.groups &&
506+
ALLOWED_METHOD.includes(discretize.method)
507+
)
508+
element = (
509+
<>
510+
<div>
511+
执行函数{' '}
512+
<Tag color='blue' style={{ margin: 0 }}>
513+
{funcs.find(
514+
(func) => func.tool.function.name === toolCall.function.name,
515+
)?.label || `未知函数 (${toolCall.function.name})`}
516+
</Tag>
517+
{done ? ', 已' : ', 是否确认'}生成变量
518+
{variable_names.map((name) => (
519+
<Tag key={name} style={{ margin: 0, marginLeft: '0.3rem' }} color='blue'>
520+
{name}
521+
</Tag>
522+
))}{' '}
523+
524+
{[
525+
standardize ? '标准化' : '',
526+
centralize ? '中心化' : '',
527+
shouldDiscritize ? `离散化 (${discretize!.method}, ${discretize!.groups} 组) ` : '',
528+
]
529+
.filter((part) => part)
530+
.join('、')}
531+
子变量
532+
</div>
533+
<div>
534+
<Button
535+
block
536+
disabled={done}
537+
onClick={() => {
538+
_VariableView_updateData(
539+
dataCols.map((col) => {
540+
if (variable_names.includes(col.name)) {
541+
return {
542+
...col,
543+
subVars: {
544+
standard: Boolean(standardize) || col.subVars?.standard,
545+
center: Boolean(centralize) || col.subVars?.center,
546+
discrete: shouldDiscritize
547+
? {
548+
method: discretize!.method,
549+
groups: discretize!.groups,
550+
}
551+
: col.subVars?.discrete,
552+
},
553+
}
554+
}
555+
return col
556+
}).filter((col) => col.derived !== true),
557+
)
558+
setDone(true)
559+
sessionStorage.setItem(id, 'done')
560+
messageApi?.success(`已成功生成变量 ${variable_names.map((name) => `"${name}"`).join('、')}${[
561+
standardize ? '标准化' : '',
562+
centralize ? '中心化' : '',
563+
shouldDiscritize ? '离散化' : '',
564+
]
565+
.filter((part) => part)
566+
.join('、')}子变量`)
567+
}}
568+
>
569+
{done ? '已生成子变量' : '确认生成子变量'}
570+
</Button>
571+
</div>
572+
</>
573+
)
574+
break
575+
}
406576
case 'create_new_var': {
407577
const { variable_name, calc_expression } = JSON.parse(args)
408578
if (!formerDone) {

src/components/variable/DataFilter.tsx

+4-1
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@ export function DataFilter() {
209209
method === ALLOWED_FILTER_METHODS.BELOW_MEDIAN
210210
) {
211211
return null
212-
} else if (method === ALLOWED_FILTER_METHODS.EQUAL || method === ALLOWED_FILTER_METHODS.NOT_EQUAL) {
212+
} else if (
213+
method === ALLOWED_FILTER_METHODS.EQUAL ||
214+
method === ALLOWED_FILTER_METHODS.NOT_EQUAL
215+
) {
213216
const variable = form.getFieldValue('variable')
214217
const options = Array.from(new Set(dataRows.map((row) => row[variable])))
215218
.sort((a, b) => Number(a) - Number(b))

src/components/variable/SubVariables.tsx

+6-4
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ type Option = {
1616
discretizeGroups?: number
1717
}
1818

19-
const DISCRETE_METHODS = Object.values(ALLOWED_DISCRETE_METHODS).map((method) => ({
20-
label: method,
21-
value: method,
22-
}))
19+
const DISCRETE_METHODS = Object.values(ALLOWED_DISCRETE_METHODS).map(
20+
(method) => ({
21+
label: method,
22+
value: method,
23+
}),
24+
)
2325
const ALLOW_SUBVARS: {
2426
en: string
2527
cn: string

src/lib/assistant/create_sub_var.ts

+83
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import type { AIFunction } from '../../types'
2+
import { ALLOWED_DISCRETE_METHODS } from '../../types'
3+
4+
export const create_sub_var: AIFunction = {
5+
label: '生成子变量',
6+
tool: {
7+
type: 'function',
8+
function: {
9+
name: 'create_sub_var',
10+
description:
11+
'你可以调用这个函数来帮助用户把指定变量的标准化/中心化/离散化子变量',
12+
parameters: {
13+
type: 'object',
14+
properties: {
15+
variable_names: {
16+
type: 'array',
17+
description: '所有要生成子变量的变量名',
18+
items: {
19+
type: 'string',
20+
description: '变量名',
21+
},
22+
},
23+
standardize: {
24+
type: 'boolean',
25+
description: '是否生成标准化子变量',
26+
},
27+
centralize: {
28+
type: 'boolean',
29+
description: '是否生成中心化子变量',
30+
},
31+
discretize: {
32+
type: 'object',
33+
description: '如果要生成离散化子变量, 则需要指定离散化算法和分组数',
34+
properties: {
35+
method: {
36+
type: 'string',
37+
description: `离散化算法, 可选值为: ${Object.values(
38+
ALLOWED_DISCRETE_METHODS,
39+
).join(', ')}`,
40+
},
41+
groups: {
42+
type: 'number',
43+
description: '离散化分组数',
44+
},
45+
},
46+
required: ['method', 'groups'],
47+
additionalProperties: false,
48+
},
49+
},
50+
required: ['variable_name'],
51+
additionalProperties: false,
52+
},
53+
strict: true,
54+
},
55+
},
56+
}
57+
58+
export const clear_sub_var: AIFunction = {
59+
label: '清除子变量',
60+
tool: {
61+
type: 'function',
62+
function: {
63+
name: 'clear_sub_var',
64+
description: '你可以调用这个函数来帮助用户清除指定变量的所有子变量',
65+
parameters: {
66+
type: 'object',
67+
properties: {
68+
variable_names: {
69+
type: 'array',
70+
description: '所有要清除子变量的变量名',
71+
items: {
72+
type: 'string',
73+
description: '变量名',
74+
},
75+
},
76+
},
77+
required: ['variable_name'],
78+
additionalProperties: false,
79+
},
80+
strict: true,
81+
},
82+
},
83+
}

src/lib/calculates/derive.ts

+5-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,11 @@ class Discrete {
9797
* @param groups 分组数
9898
* @param methed 离散化方法
9999
*/
100-
constructor(data: number[], groups: number, methed: ALLOWED_DISCRETE_METHODS) {
100+
constructor(
101+
data: number[],
102+
groups: number,
103+
methed: ALLOWED_DISCRETE_METHODS,
104+
) {
101105
this.method = methed
102106
this.groups = groups
103107
this.#data = data.toSorted((a, b) => a - b)

0 commit comments

Comments
 (0)