
[KAN] KAN Neural Network Study Bootcamp (5): hypothesis.py
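This code detects interactions, separability and symmetry among a model's input variables. It then builds a hierarchy of variable combinations (a tree diagram) that explains the model's internal structure and how it fits the data. Each function analyzes the model with numerical derivatives (gradients and Hessians), combined with clustering or simple statistics.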
I. Introduction
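KAN (Kolmogorov-Arnold Networks) is a neural network architecture based on the Kolmogorov-Arnold representation theorem. The theorem states that any multivariate continuous function on a bounded domain can be represented as a finite composition of continuous univariate functions and addition. A traditional multilayer perceptron (MLP) applies fixed activation functions at its nodes. A KAN instead places learnable activation functions on its edges. Together with its structured network design, this makes KANs promising for both function-approximation efficiency and interpretability.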
II. Technical Background and Principles
1. Kolmogorov-Arnold representation theorem
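The Kolmogorov-Arnold representation theorem states that if $f$ is a multivariate continuous function on a bounded domain, then it can be written as a finite composition of continuous functions of a single variable and the binary operation of addition. More specifically, for a smooth $f: [0,1]^n \to \mathbb{R}$,

$$f(\mathbf{x}) = f(x_1, \dots, x_n) = \sum_{q=1}^{2n+1} \Phi_q\left(\sum_{p=1}^{n} \phi_{q,p}(x_p)\right)$$

where $\phi_{q,p}: [0,1] \to \mathbb{R}$ and $\Phi_q: \mathbb{R} \to \mathbb{R}$. In a sense, this shows that the only truly multivariate function is addition, since every other function can be written using univariate functions and sums. However, this 2-layer, width-$(2n+1)$ Kolmogorov-Arnold representation may not be smooth because of its limited expressive power. KAN increases the expressive power by generalizing the representation to arbitrary depth and width.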
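The remark that addition is the only truly multivariate function can be illustrated with a classic identity. This is a toy example, not the theorem's general construction. A product of two variables needs only sums, differences and the univariate squaring function. Below is a minimal Python check; the helper names are hypothetical.

```python
def square(t):
    # a univariate function
    return t * t

def product_via_addition(x, y):
    # xy = ((x + y)^2 - (x - y)^2) / 4: only sums/differences, one univariate
    # function (squaring) and a univariate rescaling (/4) are used
    return (square(x + y) - square(x - y)) / 4

for x, y in [(2.0, 3.0), (-1.5, 4.0), (0.5, 0.5)]:
    print(x, y, product_via_addition(x, y), x * y)
```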
2. Kolmogorov-Arnold Networks (KAN)
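The Kolmogorov-Arnold representation can be written in matrix form:

$$f(\mathbf{x}) = \mathbf{\Phi}_{\rm out} \circ \mathbf{\Phi}_{\rm in} \circ \mathbf{x}$$

where

$$\mathbf{\Phi}_{\rm in} = \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n}(\cdot) \\ \vdots & & \vdots \\ \phi_{2n+1,1}(\cdot) & \cdots & \phi_{2n+1,n}(\cdot) \end{pmatrix}, \qquad \mathbf{\Phi}_{\rm out} = \begin{pmatrix} \Phi_1(\cdot) & \cdots & \Phi_{2n+1}(\cdot) \end{pmatrix}$$

Both $\mathbf{\Phi}_{\rm in}$ and $\mathbf{\Phi}_{\rm out}$ are special cases of the following function matrix $\mathbf{\Phi}$ (with $n_{\rm in}$ inputs and $n_{\rm out}$ outputs), called a Kolmogorov-Arnold layer:

$$\mathbf{\Phi} = \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n_{\rm in}}(\cdot) \\ \vdots & & \vdots \\ \phi_{n_{\rm out},1}(\cdot) & \cdots & \phi_{n_{\rm out},n_{\rm in}}(\cdot) \end{pmatrix}$$

Here $\mathbf{\Phi}_{\rm in}$ corresponds to $n_{\rm in} = n,\ n_{\rm out} = 2n+1$, and $\mathbf{\Phi}_{\rm out}$ corresponds to $n_{\rm in} = 2n+1,\ n_{\rm out} = 1$.

Once the layer is defined, a Kolmogorov-Arnold network is built simply by stacking layers. Suppose there are $L$ layers and the $l$-th layer $\mathbf{\Phi}_l$ has shape $(n_{l+1}, n_l)$. The whole network is then

$$\mathrm{KAN}(\mathbf{x}) = (\mathbf{\Phi}_{L-1} \circ \cdots \circ \mathbf{\Phi}_1 \circ \mathbf{\Phi}_0)(\mathbf{x})$$

In contrast, a multilayer perceptron interleaves linear layers $\mathbf{W}_l$ with nonlinearities $\sigma$:

$$\mathrm{MLP}(\mathbf{x}) = (\mathbf{W}_{L-1} \circ \sigma \circ \cdots \circ \mathbf{W}_1 \circ \sigma \circ \mathbf{W}_0)(\mathbf{x})$$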
A KAN is easy to visualize:

(1) A KAN is simply a stack of KAN layers.

(2) Each KAN layer can be drawn as a fully connected layer with a 1D function placed on each edge.
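The layer definition can be made concrete with a minimal pure-Python sketch. It assumes hand-picked univariate functions in place of the learnable, spline-based activations that a real KAN trains.

- A layer is an $n_{\rm out} \times n_{\rm in}$ table of callables, and output $j$ is $\sum_i \phi_{j,i}(x_i)$.
- The network applies its layers in order.
- `kan_layer` and `kan_forward` are hypothetical names, not pykan's API.
- The $[2, 1, 1]$ example represents $f(x, y) = \exp(\sin(\pi x) + y^2)$.

```python
import math

def kan_layer(phi, x):
    """Apply one Kolmogorov-Arnold layer (hypothetical helper).

    phi is an n_out x n_in nested list of univariate callables;
    output j is sum_i phi[j][i](x[i]).
    """
    return [sum(phi_ji(x_i) for phi_ji, x_i in zip(row, x)) for row in phi]

def kan_forward(layers, x):
    # KAN(x) = (Phi_{L-1} o ... o Phi_1 o Phi_0)(x): apply layers in order
    for phi in layers:
        x = kan_layer(phi, x)
    return x

# Hand-picked (not learned) edge functions for a [2, 1, 1] KAN
# representing f(x, y) = exp(sin(pi * x) + y ** 2)
layer0 = [[lambda t: math.sin(math.pi * t), lambda t: t ** 2]]  # shape (1, 2)
layer1 = [[math.exp]]                                             # shape (1, 1)

x, y = 0.3, -0.7
print(kan_forward([layer0, layer1], [x, y])[0])
print(math.exp(math.sin(math.pi * x) + y ** 2))
```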
III. Code Walkthrough
A. Function-by-function walkthrough
Separability detection functions
(1) detect_separability
def detect_separability(model, x, mode='add', score_th=1e-2, res_th=1e-2, n_clusters=None, bias=0., verbose=False):
    '''
    detect function separability

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        mode : str
            mode = 'add' or mode = 'mul'
        score_th : float
            threshold of score
        res_th : float
            threshold of residue
        n_clusters : None or int
            the number of clusters
        bias : float
            bias (for multiplicative separability)
        verbose : bool

    Returns:
    --------
        results (dictionary)

    Example1
    --------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='add')

    Example2
    --------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='mul')
    '''
    results = {}

    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x)

    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None,:] * std[:,None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
    results['hessian'] = score_mat

    dist_hard = (score_mat < score_th).float()

    if isinstance(n_clusters, int):
        n_cluster_try = [n_clusters, n_clusters]
    elif isinstance(n_clusters, list):
        n_cluster_try = n_clusters
    else:
        n_cluster_try = [1,x.shape[1]]

    n_cluster_try = list(range(n_cluster_try[0], n_cluster_try[1]+1))

    for n_cluster in n_cluster_try:

        clustering = AgglomerativeClustering(
            metric='precomputed',
            n_clusters=n_cluster,
            linkage='complete',
        ).fit(dist_hard)

        labels = clustering.labels_

        groups = [list(np.where(labels == i)[0]) for i in range(n_cluster)]
        blocks = [torch.sum(score_mat[groups[i]][:,groups[i]]) for i in range(n_cluster)]
        block_sum = torch.sum(torch.stack(blocks))
        total_sum = torch.sum(score_mat)
        residual_sum = total_sum - block_sum
        residual_ratio = residual_sum / total_sum

        if verbose == True:
            print(f'n_group={n_cluster}, residual_ratio={residual_ratio}')

        if residual_ratio < res_th:
            results['n_groups'] = n_cluster
            results['labels'] = list(labels)
            results['groups'] = groups

    if results['n_groups'] > 1:
        print(f'{mode} separability detected')
    else:
        print(f'{mode} separability not detected')

    return results
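- Purpose: detects whether a function is additively separable (mode='add') or multiplicatively separable (mode='mul'), and into which groups of variables it splits.
- How it works:
  - It computes the Hessian (second derivatives) of the model on the batch x.
  - It scales each entry by the standard deviations of the two inputs involved.
  - It takes the median absolute value over the batch to get a score matrix.
  - Entries below score_th are treated as "no interaction". This 0/1 matrix is passed to AgglomerativeClustering as a precomputed distance.
  - It keeps the cluster count whose off-block residual ratio falls below res_th. When several counts qualify, the largest one wins.
  - Additive separability is read directly from the Hessian of $f$. Multiplicative separability is checked as additive separability of $\log|f + \text{bias}|$, i.e. in log space.
- Output: a dictionary with the score matrix ('hessian'), the number of groups ('n_groups'), the per-variable cluster labels ('labels') and the groups themselves ('groups'). The function also prints whether separability was detected.
- Example: for $f(\mathbf{x}) = x_0^2 + e^{x_1 + x_2}$, it can detect additive separability into $\{x_0\}$ and $\{x_1, x_2\}$.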
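The scoring step described above can be sketched without pykan. This minimal sketch makes a few assumptions:

- It uses NumPy and central finite differences in place of the library's autograd-based `batch_hessian`.
- `fd_hessian` and `hessian_score` are hypothetical helper names.
- The clustering step is left out.

It computes the std-scaled median $|H|$ matrix for the docstring's Example1.

```python
import numpy as np

def fd_hessian(f, x, h=1e-3):
    """Central finite-difference Hessian of f, evaluated per sample.

    f maps an (N, d) array to an (N,) array; the result has shape (N, d, d).
    """
    n, d = x.shape
    H = np.zeros((n, d, d))
    for i in range(d):
        for j in range(d):
            ei = np.zeros(d); ei[i] = h
            ej = np.zeros(d); ej[j] = h
            H[:, i, j] = (f(x + ei + ej) - f(x + ei - ej)
                          - f(x - ei + ej) + f(x - ei - ej)) / (4 * h * h)
    return H

def hessian_score(f, x):
    # same scoring idea as detect_separability: scale H_ij by std_i * std_j,
    # then take the median of |H_ij| over the batch
    H = fd_hessian(f, x)
    std = x.std(axis=0, ddof=1)  # ddof=1 matches torch.std's default
    return np.median(np.abs(H * std[None, :] * std[:, None]), axis=0)

rng = np.random.default_rng(0)
x = rng.normal(0, 1, size=(100, 3))
f = lambda x: x[:, 0] ** 2 + np.exp(x[:, 1] + x[:, 2])
score = hessian_score(f, x)
print(np.round(score, 3))
```

Only the diagonal and the (1, 2) / (2, 1) entries should be clearly non-zero. That is the block structure from which the clustering step is expected to recover the groups $\{x_0\}$ and $\{x_1, x_2\}$.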
(2) test_separability
def test_separability(model, x, groups, mode='add', threshold=1e-2, bias=0):
    '''
    test function separability

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        groups : list of lists of int
            variable groups to test
        mode : str
            mode = 'add' or mode = 'mul'
        threshold : float
            threshold of the normalized Hessian score
        bias : float
            bias (for multiplicative separability)

    Returns:
    --------
        bool

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True
    >>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False
    '''
    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x)

    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None,:] * std[:,None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]

    sep_bool = True

    # internal test: cross blocks between different groups must stay below threshold
    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i+1, n_groups):
            sep_bool *= torch.max(score_mat[groups[i]][:,groups[j]]) < threshold

    # external test: grouped variables vs. variables that belong to no group
    group_id = [idx for xs in groups for idx in xs]
    nongroup_id = list(set(range(x.shape[1])) - set(group_id))

    if len(nongroup_id) > 0 and len(group_id) > 0:
        sep_bool *= torch.max(score_mat[group_id][:,nongroup_id]) < threshold

    return sep_bool
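- Purpose: given candidate variable groups, tests whether the function is separable across them.
- How it works:
  - It computes the same normalized Hessian score matrix as detect_separability, on $f$ for 'add' and on $\log|f + \text{bias}|$ for 'mul'.
  - Every cross block between two different groups must stay below threshold.
  - The block between grouped and ungrouped variables must also stay below threshold.
- Output: a boolean indicating whether the function is separable across the given groups.
- Example: for $f(\mathbf{x}) = x_0^2 (x_1 + x_2)$, the groups [[0],[1,2]] are multiplicatively separable (True) but not additively separable (False).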
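Here is why mode='mul' works on $\log|f|$. If $f = g(\mathbf{x}_A)\,h(\mathbf{x}_B)$, then $\log|f| = \log|g(\mathbf{x}_A)| + \log|h(\mathbf{x}_B)|$, so the cross second derivatives vanish.

The minimal NumPy sketch below checks this for $f = x_0^2 (x_1 + x_2)$. It compares the cross term between $x_0$ and $x_1$ before and after the log transform. It assumes a finite-difference `mixed_partial` helper (hypothetical) instead of pykan's autograd-based Hessian.

```python
import numpy as np

def mixed_partial(f, x, i, j, h=1e-3):
    # central finite-difference estimate of d^2 f / (dx_i dx_j), per sample
    ei = np.zeros(x.shape[1]); ei[i] = h
    ej = np.zeros(x.shape[1]); ej[j] = h
    return (f(x + ei + ej) - f(x + ei - ej) - f(x - ei + ej) + f(x - ei - ej)) / (4 * h * h)

rng = np.random.default_rng(0)
x = rng.normal(0, 1, size=(100, 3))
f = lambda x: x[:, 0] ** 2 * (x[:, 1] + x[:, 2])
log_f = lambda x: np.log(np.abs(f(x)))   # what mode='mul' differentiates (bias = 0)

raw = np.median(np.abs(mixed_partial(f, x, 0, 1)))         # d2f/dx0dx1 = 2*x0, non-zero
logged = np.median(np.abs(mixed_partial(log_f, x, 0, 1)))  # vanishes after the log
print(raw, logged)
```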
(3) test_general_separability
def test_general_separability(model, x, groups, threshold=1e-2):
    '''
    test general function separability

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        groups : list of lists of int
            variable groups to test
        threshold : float
            threshold passed on to test_separability

    Returns:
    --------
        bool

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_general_separability(model, x, [[1],[0,2]])) # False
    >>> print(test_general_separability(model, x, [[0],[1,2]])) # True
    '''
    gensep_bool = True

    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i+1,n_groups):
            group_A = groups[i]
            group_B = groups[j]
            for member_A in group_A:
                for member_B in group_B:
                    # ratio of partial derivatives (df/dx_B) / (df/dx_A)
                    def func(x):
                        grad = batch_jacobian(model, x, create_graph=True)
                        return grad[:,[member_B]]/grad[:,[member_A]]
                    # test if func is multiplicative separable
                    gensep_bool *= test_separability(func, x, groups, mode='mul', threshold=threshold)
    return gensep_bool
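- Purpose: tests a more general notion of separability than pure addition or multiplication. For example, functions of the form $G(g(\mathbf{x}_A) + h(\mathbf{x}_B))$ with an arbitrary outer function $G$ pass the test.
- How it works:
  - It computes the Jacobian (first derivatives).
  - For every pair of variables $a$ and $b$ from two different groups, it forms the ratio $(\partial f/\partial x_b)/(\partial f/\partial x_a)$.
  - It calls test_separability(mode='mul') to check that this ratio is multiplicatively separable across the groups.
- Output: a boolean indicating general separability.
- Example: for $f(\mathbf{x}) = x_0^2 (x_1^2 + x_2^2)^2$, the grouping [[0],[1,2]] passes and [[1],[0,2]] fails.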
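The NumPy sketch below shows why the gradient-ratio test tells the two groupings in the example apart. It makes these assumptions:

- It uses the analytic gradient of $f = x_0^2 (x_1^2 + x_2^2)^2$ and finite differences. pykan computes these with `batch_jacobian`/`batch_hessian`.
- The real test also scales by the input std and checks every cross entry. Here only one representative entry per grouping is shown.
- All helper names are hypothetical.

```python
import numpy as np

def grad_f(x):
    # analytic gradient of f = x0^2 * (x1^2 + x2^2)^2, shape (N, 3)
    x0, x1, x2 = x[:, 0], x[:, 1], x[:, 2]
    s = x1 ** 2 + x2 ** 2
    return np.stack([2 * x0 * s ** 2, 4 * x0 ** 2 * s * x1, 4 * x0 ** 2 * s * x2], axis=1)

def log_ratio(b, a):
    # log |(df/dx_b) / (df/dx_a)|: the quantity whose Hessian mode='mul' inspects
    def g(x):
        grad = grad_f(x)
        return np.log(np.abs(grad[:, b] / grad[:, a]))
    return g

def mixed_partial(g, x, i, j, h=1e-3):
    # central finite-difference estimate of d^2 g / (dx_i dx_j), per sample
    ei = np.zeros(x.shape[1]); ei[i] = h
    ej = np.zeros(x.shape[1]); ej[j] = h
    return (g(x + ei + ej) - g(x + ei - ej) - g(x - ei + ej) + g(x - ei - ej)) / (4 * h * h)

rng = np.random.default_rng(0)
x = rng.normal(0, 1, size=(100, 3))

# groups [[0], [1, 2]]: ratio for the pair (a=0, b=1), cross term between x0 and x1
good = np.median(np.abs(mixed_partial(log_ratio(1, 0), x, 0, 1)))
# groups [[1], [0, 2]]: ratio for the pair (a=1, b=0), cross term between x1 and x2
bad = np.median(np.abs(mixed_partial(log_ratio(0, 1), x, 1, 2)))
print(good, bad)
```

For the first grouping the log-ratio is $\log 2 + \log|x_0| + \log|x_1| - \log(x_1^2 + x_2^2)$, so the $x_0$-$x_1$ cross term vanishes. For the second grouping, the term $\log(x_1^2 + x_2^2)$ couples $x_1$ and $x_2$, which lie in different groups.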
Symmetry testing functions
(4) test_symmetry
def test_symmetry(model, x, group, dependence_th=1e-3):
    '''
    test function symmetry

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        group : a list of indices
        dependence_th : float
            threshold of dependence

    Returns:
    --------
        bool

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_symmetry(model, x, [1,2])) # True
    >>> print(test_symmetry(model, x, [0,2])) # False
    '''
    if len(group) == x.shape[1] or len(group) == 0:
        return True

    dependence = get_dependence(model, x, group)
    max_dependence = torch.max(dependence)
    return max_dependence < dependence_th
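- Purpose: tests whether the function is symmetric in a given group of variables, i.e. whether those variables enter $f$ only through a single combination $h(\mathbf{x}_{\text{group}})$.
- How it works:
  - It calls get_dependence to measure how strongly the normalized gradient with respect to the group changes when the remaining variables change.
  - It reports symmetry if the maximum dependence is below dependence_th.
  - An empty group, or a group containing all variables, is trivially symmetric.
- Output: a boolean indicating symmetry.
- Example: for $f(\mathbf{x}) = x_0^2 (x_1 + x_2)$, the group [1,2] is symmetric ($x_1$ and $x_2$ enter only via $x_1 + x_2$), while [0,2] is not.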
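The dependence measure behind test_symmetry can be sketched in NumPy, under these assumptions:

- The gradient of $f = x_0^2 (x_1 + x_2)$ is written out analytically.
- Central finite differences replace the autograd-based `batch_grad_normgrad`/`get_dependence`.
- `grad_f` and `dependence` are hypothetical names.

The idea is that the unit vector of the gradient restricted to the group should not move when the other variables change.

```python
import numpy as np

def grad_f(x):
    # analytic gradient of f = x0^2 * (x1 + x2), shape (N, 3)
    x0, x1, x2 = x[:, 0], x[:, 1], x[:, 2]
    return np.stack([2 * x0 * (x1 + x2), x0 ** 2, x0 ** 2], axis=1)

def dependence(x, group, h=1e-4):
    """Finite-difference analogue of get_dependence: how much the unit
    gradient restricted to `group` changes when each other variable moves."""
    rest = [k for k in range(x.shape[1]) if k not in group]
    std = x.std(axis=0, ddof=1)  # ddof=1 matches torch.std's default

    def unit_grad(z):
        g = grad_f(z)[:, group]
        return g / (np.linalg.norm(g, axis=1, keepdims=True) + 1e-6)

    dep = np.zeros((len(group), len(rest)))
    for col, k in enumerate(rest):
        e = np.zeros(x.shape[1]); e[k] = h
        d = (unit_grad(x + e) - unit_grad(x - e)) / (2 * h)  # (N, len(group))
        dep[:, col] = np.median(np.abs(d * std[group][None, :] * std[k]), axis=0)
    return dep

rng = np.random.default_rng(0)
x = rng.normal(0, 1, size=(100, 3))
print(dependence(x, [1, 2]).max())  # close to 0: x1, x2 enter only via x1 + x2
print(dependence(x, [0, 2]).max())  # clearly non-zero
```

With dependence_th=1e-3, the first group should be reported as symmetric and the second not, consistent with the docstring example.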
(5) test_symmetry_var
def test_symmetry_var(model, x, input_vars, symmetry_var):
    '''
    test symmetry

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        input_vars : list of sympy symbols
        symmetry_var : sympy expression

    Returns:
    --------
        cosine similarity

    Example
    -------
    >>> from kan.hypothesis import *
    >>> from sympy import *
    >>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,8))
    >>> input_vars = a, b, c = symbols('a b c')
    >>> symmetry_var = b + c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    >>> symmetry_var = b * c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    '''
    orig_vars = input_vars
    sym_var = symmetry_var

    # gradients wrt to input (model)
    input_grad = batch_jacobian(model, x)

    # gradients wrt to input (symmetry var)
    func = lambdify(orig_vars, sym_var,'numpy') # returns a numpy-ready function

    func2 = lambda x: func(*[x[:,[i]] for i in range(len(orig_vars))])
    sym_grad = batch_jacobian(func2, x)

    # get id
    idx = []
    sym_symbols = list(sym_var.free_symbols)
    for sym_symbol in sym_symbols:
        for j in range(len(orig_vars)):
            if sym_symbol == orig_vars[j]:
                idx.append(j)

    input_grad_part = input_grad[:,idx]
    sym_grad_part = sym_grad[:,idx]

    cossim = torch.abs(torch.sum(input_grad_part * sym_grad_part, dim=1)/(torch.norm(input_grad_part, dim=1)*torch.norm(sym_grad_part, dim=1)))

    ratio = torch.sum(cossim > 0.9)/len(cossim)

    print(f'{100*ratio}% data have more than 0.9 cosine similarity')
    if ratio > 0.9:
        print('suggesting symmetry')
    else:
        print('not suggesting symmetry')

    return cossim
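- Purpose: tests whether the function is symmetric with respect to a given symbolic expression, i.e. a candidate combination of variables written in sympy.
- How it works:
  - It computes the gradient of the model and the gradient of the symbolic expression (via lambdify and batch_jacobian).
  - It restricts both gradients to the variables that appear in the expression.
  - It computes their absolute cosine similarity per sample.
- Output: prints the percentage of samples with cosine similarity above 0.9, and suggests symmetry if that share exceeds 90%. Returns the per-sample cosine similarities.
- Example: for $f = a(b + c)$, symmetry with respect to $b + c$ is suggested, whereas symmetry with respect to $bc$ is not.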
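Here is a NumPy sketch of the cosine-similarity check for the docstring example $f = a(b + c)$. It assumes hand-derived gradients instead of `sympy.lambdify` plus `batch_jacobian`. The names `model_grad` and `agreement` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(0, 1, size=(100, 3))
a, b, c = x[:, 0], x[:, 1], x[:, 2]

# gradient of the model f = a * (b + c), restricted to the variables (b, c)
model_grad = np.stack([a, a], axis=1)

def agreement(sym_grad):
    """Per-sample |cosine similarity| with model_grad, and the share above 0.9."""
    cos = np.abs(np.sum(model_grad * sym_grad, axis=1)) / (
        np.linalg.norm(model_grad, axis=1) * np.linalg.norm(sym_grad, axis=1))
    return cos, np.mean(cos > 0.9)

# candidate s = b + c: gradient (1, 1)
_, ratio_sum = agreement(np.stack([np.ones_like(b), np.ones_like(c)], axis=1))
# candidate s = b * c: gradient (c, b)
_, ratio_prod = agreement(np.stack([c, b], axis=1))
print(ratio_sum, ratio_prod)
```

For $b + c$ every sample agrees ($|\cos| = 1$). For $bc$ only a minority of samples exceed 0.9, so only the first candidate passes the 90% rule used by test_symmetry_var.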
Hierarchy analysis functions
(6) get_molecule
def get_molecule(model, x, sym_th=1e-3, verbose=True):
    '''
    how variables are combined hierarchically

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        sym_th : float
            threshold of symmetry
        verbose : bool

    Returns:
    --------
        list

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> get_molecule(model, x, verbose=False)
    [[[0], [1], [2], [3], [4], [5], [6], [7]],
     [[0, 1], [2, 3], [4, 5], [6, 7]],
     [[0, 1, 2, 3], [4, 5, 6, 7]],
     [[0, 1, 2, 3, 4, 5, 6, 7]]]
    '''
    n = x.shape[1]
    atoms = [[i] for i in range(n)]
    molecules = []
    moleculess = [copy.deepcopy(atoms)]
    already_full = False
    n_layer = 0
    last_n_molecule = n

    while True:

        pointer = 0
        current_molecule = []
        remove_atoms = []
        n_atom = 0

        while len(atoms) > 0:

            # assemble molecule
            atom = atoms[pointer]
            if verbose:
                print(current_molecule)
                print(atom)

            if len(current_molecule) == 0:
                full = False
                current_molecule += atom
                remove_atoms.append(atom)
                n_atom += 1
            else:
                # try assemble the atom to the molecule
                if len(current_molecule+atom) == x.shape[1] and already_full == False and n_atom > 1 and n_layer > 0:
                    full = True
                    already_full = True
                else:
                    full = False
                    if test_symmetry(model, x, current_molecule+atom, dependence_th=sym_th):
                        current_molecule += atom
                        remove_atoms.append(atom)
                        n_atom += 1

            pointer += 1

            if pointer == len(atoms) or full:
                molecules.append(current_molecule)
                if full:
                    molecules.append(atom)
                    remove_atoms.append(atom)
                # remove molecules from atoms
                for atom in remove_atoms:
                    atoms.remove(atom)
                current_molecule = []
                remove_atoms = []
                pointer = 0

        # if not making progress, terminate
        if len(molecules) == last_n_molecule:
            def flatten(xss):
                return [x for xs in xss for x in xs]
            moleculess.append([flatten(molecules)])
            break
        else:
            moleculess.append(copy.deepcopy(molecules))

        last_n_molecule = len(molecules)

        if len(molecules) == 1:
            break

        atoms = molecules
        molecules = []

        n_layer += 1

        #print(n_layer, atoms)

    # sort
    depth = len(moleculess) - 1

    for l in list(range(depth,0,-1)):

        molecules_sorted = []
        molecules_l = moleculess[l]
        molecules_lm1 = moleculess[l-1]

        for molecule_l in molecules_l:
            start = 0
            for i in range(1,len(molecule_l)+1):
                if molecule_l[start:i] in molecules_lm1:
                    molecules_sorted.append(molecule_l[start:i])
                    start = i

        moleculess[l-1] = molecules_sorted

    return moleculess
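- Purpose: identifies a hierarchical grouping of variables ("molecules") based on symmetry.
- How it works:
  - It starts from single variables ("atoms") and greedily merges atoms into larger molecules. A merge is accepted only if test_symmetry (with threshold sym_th) holds for the combined group.
  - The molecules of one layer become the atoms of the next layer.
  - It stops when a single molecule remains or no further merge is possible. In the latter case all variables are lumped into one final group.
  - Finally, it reorders the lower layers so that they are consistent with the ordering of the layers above.
- Output: a list of layers, where each layer is a list of variable groups.
- Example: for the eight-variable model in the docstring, the result is [[[0],[1],...,[7]], [[0,1],[2,3],[4,5],[6,7]], [[0,1,2,3],[4,5,6,7]], [[0,1,...,7]]].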
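The greedy merging idea can be sketched in plain Python. This is a simplified illustration, not pykan's implementation:

- `build_molecules` and the `is_symmetric` oracle are hypothetical.
- The oracle stands in for test_symmetry.
- The 'full' special case and the final re-sorting of get_molecule are omitted.

The oracle encodes a four-variable function such as $f = (x_0^2 + x_1^2)^2 + (x_2^2 + x_3^2)^2$. Its symmetric groups are single variables, $\{x_0, x_1\}$, $\{x_2, x_3\}$ and all four variables together.

```python
def build_molecules(n, is_symmetric):
    """Simplified greedy merge in the spirit of get_molecule.

    is_symmetric(group) stands in for test_symmetry(model, x, group).
    """
    atoms = [[i] for i in range(n)]
    layers = [[list(a) for a in atoms]]
    while len(atoms) > 1:
        molecules = []
        remaining = list(atoms)
        while remaining:
            # start a molecule from the first unused atom, then try to absorb the others
            molecule = list(remaining.pop(0))
            for atom in list(remaining):
                if is_symmetric(molecule + atom):
                    molecule += atom
                    remaining.remove(atom)
            molecules.append(molecule)
        if len(molecules) == len(atoms):
            # no progress: lump everything into one group and stop
            layers.append([[i for m in molecules for i in m]])
            break
        layers.append(molecules)
        atoms = molecules
    return layers

valid = [{0, 1}, {2, 3}, {0, 1, 2, 3}]
oracle = lambda group: len(group) == 1 or set(group) in valid
print(build_molecules(4, oracle))
```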
(7) get_tree_node
def get_tree_node(model, x, moleculess, sep_th=1e-2, skip_test=True):
    '''
    get tree nodes

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        sep_th : float
            threshold of separability
        skip_test : bool
            if True, don't test the property of each module (to save time)

    Returns:
    --------
        arities : list of numbers
        properties : list of strings

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> moleculess = get_molecule(model, x, verbose=False)
    >>> get_tree_node(model, x, moleculess, skip_test=False)
    '''
    arities = []
    properties = []

    depth = len(moleculess) - 1

    for l in range(depth):
        molecules_l = copy.deepcopy(moleculess[l])
        molecules_lp1 = copy.deepcopy(moleculess[l+1])
        arity_l = []
        property_l = []

        for molecule in molecules_lp1:
            start = 0
            arity = 0
            groups = []
            for i in range(1,len(molecule)+1):
                if molecule[start:i] in molecules_l:
                    groups.append(molecule[start:i])
                    start = i
                    arity += 1
            arity_l.append(arity)

            if arity == 1:
                property = 'Id'
            else:
                property = ''
                # test property
                if skip_test:
                    gensep_bool = False
                else:
                    gensep_bool = test_general_separability(model, x, groups, threshold=sep_th)

                if gensep_bool:
                    property = 'GS'
                if l == depth - 1:
                    if skip_test:
                        add_bool = False
                        mul_bool = False
                    else:
                        add_bool = test_separability(model, x, groups, mode='add', threshold=sep_th)
                        mul_bool = test_separability(model, x, groups, mode='mul', threshold=sep_th)
                    if add_bool:
                        property = 'Add'
                    if mul_bool:
                        property = 'Mul'

            property_l.append(property)

        arities.append(arity_l)
        properties.append(property_l)

    return arities, properties
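- Purpose: turns the molecule hierarchy into tree nodes and labels each node with a property: 'Id', 'GS', 'Add', 'Mul' or '' (undetermined).
- How it works:
  - For each molecule in layer l+1, it counts how many consecutive molecules of layer l make it up. This count is the node's arity.
  - A node with arity 1 is labeled 'Id' (identity).
  - Otherwise, unless skip_test is True (the default), the node is labeled 'GS' if test_general_separability passes.
  - At the top layer only, the node is additionally tested for additive ('Add') and multiplicative ('Mul') separability, which override 'GS'.
- Output: two nested lists, arities and properties, with one entry per node per layer.
- Example: on the eight-variable hierarchy above, every node has arity 2, and only the top node is tested for 'Add'/'Mul'.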
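The arity bookkeeping in get_tree_node can be reproduced in isolation. It scans each upper-layer molecule for consecutive lower-layer molecules. Below is a minimal sketch with a hypothetical function name and a hand-written hierarchy. No separability tests are run, so only the 'Id' and '' labels appear.

```python
def arities_and_properties(layers):
    """Count how many consecutive lower-layer molecules make up each
    upper-layer molecule (its arity), as get_tree_node does; nodes are
    labeled 'Id' (arity 1) or '' because no separability tests are run."""
    arities, properties = [], []
    for lower, upper in zip(layers[:-1], layers[1:]):
        arity_l, prop_l = [], []
        for molecule in upper:
            start, arity = 0, 0
            for i in range(1, len(molecule) + 1):
                if molecule[start:i] in lower:
                    start, arity = i, arity + 1
            arity_l.append(arity)
            prop_l.append('Id' if arity == 1 else '')
        arities.append(arity_l)
        properties.append(prop_l)
    return arities, properties

layers = [[[0], [1], [2], [3], [4]], [[0, 1], [2], [3, 4]], [[0, 1, 2, 3, 4]]]
print(arities_and_properties(layers))
```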
(8) plot_tree
def plot_tree(model, x, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
    '''
    get tree graph

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        in_var : list of symbols
            input variables
        style : str
            'tree' or 'box'
        sym_th : float
            threshold of symmetry
        sep_th : float
            threshold of separability
        skip_sep_test : bool
            if True, don't test the property of each module (to save time)
        verbose : bool

    Returns:
    --------
        a tree graph

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> plot_tree(model, x)
    '''
    moleculess = get_molecule(model, x, sym_th=sym_th, verbose=verbose)
    arities, properties = get_tree_node(model, x, moleculess, sep_th=sep_th, skip_test=skip_sep_test)

    n = x.shape[1]

    # input symbols: x_1, ..., x_n by default; otherwise taken from in_var
    # (sympy symbols are used as-is, strings are converted with sympy.symbols)
    if in_var is None:
        in_vars = [sympy.Symbol(f'x_{ii}') for ii in range(1, n + 1)]
    elif isinstance(in_var[0], sympy.Symbol):
        in_vars = in_var
    else:
        in_vars = [sympy.symbols(var_) for var_ in in_var]

    def flatten(xss):
        return [x for xs in xss for x in xs]

    def myrectangle(center_x, center_y, width_x, width_y):
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y + width_y/2, center_y + width_y/2], color='k') # up
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y - width_y/2], color='k') # down
        plt.plot([center_x - width_x/2, center_x - width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # left
        plt.plot([center_x + width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # right

    depth = len(moleculess)

    delta = 1/n
    a = 0.3
    b = 0.15
    y0 = 0.5

    # draw rectangles
    for l in range(depth-1):
        molecules = moleculess[l+1]
        n_molecule = len(molecules)

        centers = []

        acc_arity = 0

        for i in range(n_molecule):
            start_id = len(flatten(molecules[:i]))
            end_id = len(flatten(molecules[:i+1]))

            center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2
            center_y = (l+1/2)*y0
            width_x = (end_id - start_id - 1 + 2*a)*delta
            width_y = 2*b

            # add text (numbers) on rectangles
            if style == 'box':
                myrectangle(center_x, center_y, width_x, width_y)
                plt.text(center_x, center_y, properties[l][i], fontsize=15, horizontalalignment='center',
                         verticalalignment='center')
            elif style == 'tree':
                # if 'GS', no rectangle, n=arity tilted lines
                # if 'Id', no rectangle, one vertical line
                # if 'Add' or 'Mul', n=arity tilted lines plus a "+" or "*"
                # if '', rectangle
                property = properties[l][i]
                if property == 'GS' or property == 'Add' or property == 'Mul':
                    color = 'blue'
                    arity = arities[l][i]
                    for j in range(arity):
                        if l == 0:
                            # x = (start_id + j) * delta + delta/2, center_x
                            # y = center_y - b, center_y + b
                            plt.plot([(start_id + j) * delta + delta/2, center_x], [center_y - b, center_y + b], color=color)
                        else:
                            # x = last_centers[acc_arity:acc_arity+arity], center_x
                            # y = center_y - b, center_y + b
                            plt.plot([last_centers[acc_arity+j], center_x], [center_y - b, center_y + b], color=color)

                    acc_arity += arity

                    if property == 'Add' or property == 'Mul':
                        if property == 'Add':
                            symbol = '+'
                        else:
                            symbol = '*'

                        plt.text(center_x, center_y + b, symbol, horizontalalignment='center',
                                 verticalalignment='center', color='red', fontsize=40)
                if property == 'Id':
                    plt.plot([center_x, center_x], [center_y-width_y/2, center_y+width_y/2], color='black')

                if property == '':
                    myrectangle(center_x, center_y, width_x, width_y)

            # connections to the next layer
            plt.plot([center_x, center_x], [center_y+width_y/2, center_y+y0-width_y/2], color='k')
            centers.append(center_x)
        last_centers = copy.deepcopy(centers)

    # connections from input variables to the first layer
    for i in range(n):
        x_ = (i + 1/2) * delta
        # connections to the next layer
        plt.plot([x_, x_], [0, y0/2-width_y/2], color='k')
        plt.text(x_, -0.05*(depth-1), f'${latex(in_vars[moleculess[0][i][0]])}$', fontsize=20, horizontalalignment='center')
    plt.xlim(0,1)
    #plt.ylim(0,1)
    plt.axis('off')
    plt.show()
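- Purpose: visualizes the inferred tree structure of the function.
- How it works:
  - It runs get_molecule and get_tree_node, then draws the result with matplotlib.
  - With style='box', each node is a rectangle labeled with its property.
  - With style='tree', 'GS'/'Add'/'Mul' nodes are drawn as tilted blue lines, with a red "+" or "*" for 'Add'/'Mul'. 'Id' nodes are a vertical line, and unlabeled nodes are a rectangle.
  - Input variables are printed at the bottom as LaTeX symbols: $x_1, \dots, x_n$ by default, or those given in in_var.
- Output: displays a figure of the function's structure.
- Example: plots the hierarchy of the eight-variable function above.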
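The horizontal placement in plot_tree follows a simple rule. With $n$ inputs, each input occupies a slot of width $\delta = 1/n$. Each node is centered over the slots its molecule covers and has width $(\text{size} - 1 + 2a)\,\delta$. Below is a matplotlib-free sketch of that arithmetic, using the hypothetical function name `box_layout` and $a = 0.3$ as in plot_tree.

```python
def box_layout(molecules, n, a=0.3):
    """Horizontal placement used by plot_tree for one layer: each molecule
    is centered over the input slots it covers (slot width delta = 1/n)."""
    delta = 1 / n
    boxes = []
    start_id = 0
    for molecule in molecules:
        end_id = start_id + len(molecule)
        center_x = (start_id + (end_id - 1 - start_id) / 2) * delta + delta / 2
        width_x = (end_id - start_id - 1 + 2 * a) * delta
        boxes.append((center_x, width_x))
        start_id = end_id
    return boxes

print(box_layout([[0, 1], [2, 3]], n=4))
```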
Helper functions
(9) batch_grad_normgrad
def batch_grad_normgrad(model, x, group, create_graph=False):
    # x in shape (Batch, Length)
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))

    def jac(x):
        input_grad = batch_jacobian(model, x, create_graph=True)
        input_grad_A = input_grad[:,group_A]
        norm = torch.norm(input_grad_A, dim=1, keepdim=True) + 1e-6
        input_grad_A_normalized = input_grad_A/norm
        return input_grad_A_normalized

    def _jac_sum(x):
        return jac(x).sum(dim=0)

    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1,0,2)[:,:,group_B]
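- Computes, for each sample, the derivative of the normalized gradient over the group, $\nabla_A f / \lVert \nabla_A f \rVert$, with respect to the remaining variables. The result has shape (batch, |group|, |rest|) and is used for dependence and symmetry analysis.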
(10) get_dependence
def get_dependence(model, x, group):
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))
    grad_normgrad = batch_grad_normgrad(model, x, group=group)
    std = torch.std(x, dim=0)
    dependence = grad_normgrad * std[None,group_A,None] * std[None,None,group_B]
    dependence = torch.median(torch.abs(dependence), dim=0)[0]
    return dependence
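- Measures how much the group depends on the remaining variables. It scales the output of batch_grad_normgrad by the input standard deviations and takes the median absolute value over the batch, giving a (|group|, |rest|) matrix. Used by test_symmetry.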
B. Full code
import numpy as np
import torch
from sklearn.linear_model import LinearRegression
from sympy.utilities.lambdify import lambdify
from sklearn.cluster import AgglomerativeClustering
from functools import reduce
from kan.utils import batch_jacobian, batch_hessian
import copy
import matplotlib.pyplot as plt
import sympy
from sympy.printing import latex
def detect_separability(model, x, mode='add', score_th=1e-2, res_th=1e-2, n_clusters=None, bias=0., verbose=False):
    '''
    detect function separability

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        mode : str
            mode = 'add' or mode = 'mul'
        score_th : float
            threshold of score
        res_th : float
            threshold of residue
        n_clusters : None or int
            the number of clusters
        bias : float
            bias (for multiplicative separability)
        verbose : bool

    Returns:
    --------
        results (dictionary)

    Example1
    --------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='add')

    Example2
    --------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> detect_separability(model, x, mode='mul')
    '''
    results = {}

    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x)

    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None,:] * std[:,None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
    results['hessian'] = score_mat

    dist_hard = (score_mat < score_th).float()

    if isinstance(n_clusters, int):
        n_cluster_try = [n_clusters, n_clusters]
    elif isinstance(n_clusters, list):
        n_cluster_try = n_clusters
    else:
        n_cluster_try = [1,x.shape[1]]

    n_cluster_try = list(range(n_cluster_try[0], n_cluster_try[1]+1))

    for n_cluster in n_cluster_try:

        clustering = AgglomerativeClustering(
            metric='precomputed',
            n_clusters=n_cluster,
            linkage='complete',
        ).fit(dist_hard)

        labels = clustering.labels_

        groups = [list(np.where(labels == i)[0]) for i in range(n_cluster)]
        blocks = [torch.sum(score_mat[groups[i]][:,groups[i]]) for i in range(n_cluster)]
        block_sum = torch.sum(torch.stack(blocks))
        total_sum = torch.sum(score_mat)
        residual_sum = total_sum - block_sum
        residual_ratio = residual_sum / total_sum

        if verbose == True:
            print(f'n_group={n_cluster}, residual_ratio={residual_ratio}')

        if residual_ratio < res_th:
            results['n_groups'] = n_cluster
            results['labels'] = list(labels)
            results['groups'] = groups

    if results['n_groups'] > 1:
        print(f'{mode} separability detected')
    else:
        print(f'{mode} separability not detected')

    return results
def batch_grad_normgrad(model, x, group, create_graph=False):
    # x in shape (Batch, Length)
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))

    def jac(x):
        input_grad = batch_jacobian(model, x, create_graph=True)
        input_grad_A = input_grad[:,group_A]
        norm = torch.norm(input_grad_A, dim=1, keepdim=True) + 1e-6
        input_grad_A_normalized = input_grad_A/norm
        return input_grad_A_normalized

    def _jac_sum(x):
        return jac(x).sum(dim=0)

    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1,0,2)[:,:,group_B]
def get_dependence(model, x, group):
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))
    grad_normgrad = batch_grad_normgrad(model, x, group=group)
    std = torch.std(x, dim=0)
    dependence = grad_normgrad * std[None,group_A,None] * std[None,None,group_B]
    dependence = torch.median(torch.abs(dependence), dim=0)[0]
    return dependence
def test_symmetry(model, x, group, dependence_th=1e-3):
    '''
    test function symmetry

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        group : a list of indices
        dependence_th : float
            threshold of dependence

    Returns:
    --------
        bool

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_symmetry(model, x, [1,2])) # True
    >>> print(test_symmetry(model, x, [0,2])) # False
    '''
    if len(group) == x.shape[1] or len(group) == 0:
        return True

    dependence = get_dependence(model, x, group)
    max_dependence = torch.max(dependence)
    return max_dependence < dependence_th
def test_separability(model, x, groups, mode='add', threshold=1e-2, bias=0):
    '''
    test function separability

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        groups : list of lists of int
            variable groups to test
        mode : str
            mode = 'add' or mode = 'mul'
        threshold : float
            threshold of the normalized Hessian score
        bias : float
            bias (for multiplicative separability)

    Returns:
    --------
        bool

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True
    >>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False
    '''
    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x)

    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None,:] * std[:,None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]

    sep_bool = True

    # internal test: cross blocks between different groups must stay below threshold
    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i+1, n_groups):
            sep_bool *= torch.max(score_mat[groups[i]][:,groups[j]]) < threshold

    # external test: grouped variables vs. variables that belong to no group
    group_id = [idx for xs in groups for idx in xs]
    nongroup_id = list(set(range(x.shape[1])) - set(group_id))

    if len(nongroup_id) > 0 and len(group_id) > 0:
        sep_bool *= torch.max(score_mat[group_id][:,nongroup_id]) < threshold

    return sep_bool
def test_general_separability(model, x, groups, threshold=1e-2):
    '''
    test general function separability

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        groups : list of lists of int
            variable groups to test
        threshold : float
            threshold passed on to test_separability

    Returns:
    --------
        bool

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2
    >>> x = torch.normal(0,1,size=(100,3))
    >>> print(test_general_separability(model, x, [[1],[0,2]])) # False
    >>> print(test_general_separability(model, x, [[0],[1,2]])) # True
    '''
    gensep_bool = True

    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i+1,n_groups):
            group_A = groups[i]
            group_B = groups[j]
            for member_A in group_A:
                for member_B in group_B:
                    # ratio of partial derivatives (df/dx_B) / (df/dx_A)
                    def func(x):
                        grad = batch_jacobian(model, x, create_graph=True)
                        return grad[:,[member_B]]/grad[:,[member_A]]
                    # test if func is multiplicative separable
                    gensep_bool *= test_separability(func, x, groups, mode='mul', threshold=threshold)
    return gensep_bool
def get_molecule(model, x, sym_th=1e-3, verbose=True):
    '''
    how variables are combined hierarchically

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        sym_th : float
            threshold of symmetry
        verbose : bool

    Returns:
    --------
        list

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> get_molecule(model, x, verbose=False)
    [[[0], [1], [2], [3], [4], [5], [6], [7]],
     [[0, 1], [2, 3], [4, 5], [6, 7]],
     [[0, 1, 2, 3], [4, 5, 6, 7]],
     [[0, 1, 2, 3, 4, 5, 6, 7]]]
    '''
    n = x.shape[1]
    atoms = [[i] for i in range(n)]
    molecules = []
    moleculess = [copy.deepcopy(atoms)]
    already_full = False
    n_layer = 0
    last_n_molecule = n

    while True:

        pointer = 0
        current_molecule = []
        remove_atoms = []
        n_atom = 0

        while len(atoms) > 0:

            # assemble molecule
            atom = atoms[pointer]
            if verbose:
                print(current_molecule)
                print(atom)

            if len(current_molecule) == 0:
                full = False
                current_molecule += atom
                remove_atoms.append(atom)
                n_atom += 1
            else:
                # try assemble the atom to the molecule
                if len(current_molecule+atom) == x.shape[1] and already_full == False and n_atom > 1 and n_layer > 0:
                    full = True
                    already_full = True
                else:
                    full = False
                    if test_symmetry(model, x, current_molecule+atom, dependence_th=sym_th):
                        current_molecule += atom
                        remove_atoms.append(atom)
                        n_atom += 1

            pointer += 1

            if pointer == len(atoms) or full:
                molecules.append(current_molecule)
                if full:
                    molecules.append(atom)
                    remove_atoms.append(atom)
                # remove molecules from atoms
                for atom in remove_atoms:
                    atoms.remove(atom)
                current_molecule = []
                remove_atoms = []
                pointer = 0

        # if not making progress, terminate
        if len(molecules) == last_n_molecule:
            def flatten(xss):
                return [x for xs in xss for x in xs]
            moleculess.append([flatten(molecules)])
            break
        else:
            moleculess.append(copy.deepcopy(molecules))

        last_n_molecule = len(molecules)

        if len(molecules) == 1:
            break

        atoms = molecules
        molecules = []

        n_layer += 1

        #print(n_layer, atoms)

    # sort
    depth = len(moleculess) - 1

    for l in list(range(depth,0,-1)):

        molecules_sorted = []
        molecules_l = moleculess[l]
        molecules_lm1 = moleculess[l-1]

        for molecule_l in molecules_l:
            start = 0
            for i in range(1,len(molecule_l)+1):
                if molecule_l[start:i] in molecules_lm1:
                    molecules_sorted.append(molecule_l[start:i])
                    start = i

        moleculess[l-1] = molecules_sorted

    return moleculess
def get_tree_node(model, x, moleculess, sep_th=1e-2, skip_test=True):
    '''
    get tree nodes

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        sep_th : float
            threshold of separability
        skip_test : bool
            if True, don't test the property of each module (to save time)

    Returns:
    --------
        arities : list of numbers
        properties : list of strings

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> moleculess = get_molecule(model, x, verbose=False)
    >>> get_tree_node(model, x, moleculess, skip_test=False)
    '''
    arities = []
    properties = []

    depth = len(moleculess) - 1

    for l in range(depth):
        molecules_l = copy.deepcopy(moleculess[l])
        molecules_lp1 = copy.deepcopy(moleculess[l+1])
        arity_l = []
        property_l = []

        for molecule in molecules_lp1:
            start = 0
            arity = 0
            groups = []
            for i in range(1,len(molecule)+1):
                if molecule[start:i] in molecules_l:
                    groups.append(molecule[start:i])
                    start = i
                    arity += 1
            arity_l.append(arity)

            if arity == 1:
                property = 'Id'
            else:
                property = ''
                # test property
                if skip_test:
                    gensep_bool = False
                else:
                    gensep_bool = test_general_separability(model, x, groups, threshold=sep_th)

                if gensep_bool:
                    property = 'GS'
                if l == depth - 1:
                    if skip_test:
                        add_bool = False
                        mul_bool = False
                    else:
                        add_bool = test_separability(model, x, groups, mode='add', threshold=sep_th)
                        mul_bool = test_separability(model, x, groups, mode='mul', threshold=sep_th)
                    if add_bool:
                        property = 'Add'
                    if mul_bool:
                        property = 'Mul'

            property_l.append(property)

        arities.append(arity_l)
        properties.append(property_l)

    return arities, properties
def plot_tree(model, x, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
    '''
    get tree graph

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        in_var : list of sympy symbols or strings
            input variables
        style : str
            'tree' or 'box'
        sym_th : float
            threshold of symmetry
        sep_th : float
            threshold of separability
        skip_sep_test : bool
            if True, don't test the property of each module (to save time)
        verbose : bool

    Returns:
    --------
        a tree graph

    Example
    -------
    >>> from kan.hypothesis import *
    >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
    >>> x = torch.normal(0,1,size=(100,8))
    >>> plot_tree(model, x)
    '''
    moleculess = get_molecule(model, x, sym_th=sym_th, verbose=verbose)
    arities, properties = get_tree_node(model, x, moleculess, sep_th=sep_th, skip_test=skip_sep_test)

    n = x.shape[1]

    # leaf labels: default x_1, ..., x_n; otherwise use (or convert) in_var
    if in_var is None:
        in_vars = [sympy.Symbol(f'x_{ii}') for ii in range(1, n + 1)]
    elif isinstance(in_var[0], sympy.Symbol):
        in_vars = in_var
    else:
        in_vars = [sympy.symbols(var_) for var_ in in_var]

    def flatten(xss):
        return [x for xs in xss for x in xs]

    def myrectangle(center_x, center_y, width_x, width_y):
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y + width_y/2, center_y + width_y/2], color='k') # up
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y - width_y/2], color='k') # down
        plt.plot([center_x - width_x/2, center_x - width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # left
        plt.plot([center_x + width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # right

    depth = len(moleculess)

    delta = 1/n      # horizontal spacing between leaves
    a = 0.3          # horizontal padding of a box, in units of delta
    b = 0.15         # half-height of a box
    y0 = 0.5         # vertical distance between levels
    width_y = 2*b    # box height (constant; also used after the loop)

    # draw rectangles
    for l in range(depth-1):
        molecules = moleculess[l+1]
        n_molecule = len(molecules)

        centers = []
        acc_arity = 0

        for i in range(n_molecule):
            start_id = len(flatten(molecules[:i]))
            end_id = len(flatten(molecules[:i+1]))

            center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2
            center_y = (l+1/2)*y0
            width_x = (end_id - start_id - 1 + 2*a)*delta

            # add text (numbers) on rectangles
            if style == 'box':
                myrectangle(center_x, center_y, width_x, width_y)
                plt.text(center_x, center_y, properties[l][i], fontsize=15, horizontalalignment='center',
                         verticalalignment='center')
            elif style == 'tree':
                # 'GS': no rectangle, arity tilted lines
                # 'Id': no rectangle, one vertical line (arity is 1)
                # 'Add' / 'Mul': arity tilted lines plus a "+" or "*" marker
                # '': rectangle
                property = properties[l][i]
                arity = arities[l][i]
                if property == 'GS' or property == 'Add' or property == 'Mul':
                    color = 'blue'
                    for j in range(arity):
                        if l == 0:
                            # x = (start_id + j) * delta + delta/2, center_x
                            # y = center_y - b, center_y + b
                            plt.plot([(start_id + j) * delta + delta/2, center_x], [center_y - b, center_y + b], color=color)
                        else:
                            # x = last_centers[acc_arity:acc_arity+arity], center_x
                            # y = center_y - b, center_y + b
                            plt.plot([last_centers[acc_arity+j], center_x], [center_y - b, center_y + b], color=color)

                if property == 'Add' or property == 'Mul':
                    if property == 'Add':
                        symbol = '+'
                    else:
                        symbol = '*'
                    plt.text(center_x, center_y + b, symbol, horizontalalignment='center',
                             verticalalignment='center', color='red', fontsize=40)

                if property == 'Id':
                    plt.plot([center_x, center_x], [center_y-width_y/2, center_y+width_y/2], color='black')

                if property == '':
                    myrectangle(center_x, center_y, width_x, width_y)

                # skip this node's children for every property, so later nodes
                # at this level index the correct entries of last_centers
                acc_arity += arity

            # connections to the next layer
            plt.plot([center_x, center_x], [center_y+width_y/2, center_y+y0-width_y/2], color='k')
            centers.append(center_x)

        last_centers = copy.deepcopy(centers)

    # connections from input variables to the first layer
    for i in range(n):
        x_ = (i + 1/2) * delta
        plt.plot([x_, x_], [0, y0/2-width_y/2], color='k')
        plt.text(x_, -0.05*(depth-1), f'${latex(in_vars[moleculess[0][i][0]])}$', fontsize=20, horizontalalignment='center')
    plt.xlim(0,1)
    #plt.ylim(0,1)
    plt.axis('off')
    plt.show()
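plot_tree places every node using two pieces of bookkeeping. The first is the node's arity, meaning how many lower-level groups it merges. The second is its x-coordinate, which is the midpoint of the leaf columns it covers: `center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2`. The hypothetical helper `tree_layout` below reproduces only that arithmetic, without any plotting. Like plot_tree, it assumes each level lists its groups in leaf order. It is meant to make the geometry concrete; in pykan the arities come from get_tree_node.

```python
def tree_layout(moleculess):
    """Hypothetical helper: per-level arities and x-centers, as in plot_tree.

    moleculess[l] is the list of variable groups at level l; level 0 holds
    one singleton group per input variable.
    """
    n = len(moleculess[0])          # number of leaves (input variables)
    delta = 1 / n                   # horizontal spacing between leaves
    arities, centers = [], []
    for l in range(len(moleculess) - 1):
        lower, upper = moleculess[l], moleculess[l + 1]
        arity_l, centers_l = [], []
        start = 0
        for mol in upper:
            end = start + len(mol)
            # arity: number of lower-level groups merged into this group
            arity_l.append(sum(1 for m in lower if set(m) <= set(mol)))
            # midpoint of leaf columns start .. end-1 (same formula as plot_tree)
            centers_l.append((start + (end - 1 - start) / 2) * delta + delta / 2)
            start = end
        arities.append(arity_l)
        centers.append(centers_l)
    return arities, centers
```

For the 8-variable example in the docstring, the expected hierarchy (pairs, then quadruples, then everything) gives arities `[[2, 2, 2, 2], [2, 2], [2]]`, and the root sits at x = 0.5.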
def test_symmetry_var(model, x, input_vars, symmetry_var):
    '''
    test symmetry

    Args:
    -----
        model : MultKAN, MLP or python function
        x : 2D torch.float
            inputs
        input_vars : list of sympy symbols
        symmetry_var : sympy expression

    Returns:
    --------
        cosine similarity

    Example
    -------
    >>> from kan.hypothesis import *
    >>> from sympy import *
    >>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
    >>> x = torch.normal(0,1,size=(100,3))
    >>> input_vars = a, b, c = symbols('a b c')
    >>> symmetry_var = b + c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    >>> symmetry_var = b * c
    >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    '''
    orig_vars = input_vars
    sym_var = symmetry_var

    # gradients w.r.t. the input (model)
    input_grad = batch_jacobian(model, x)

    # gradients w.r.t. the input (symmetry variable)
    # the 'numpy' backend works here because +, -, * act directly on torch tensors;
    # expressions that call numpy functions (e.g. sin, exp) may fail under autograd
    func = lambdify(orig_vars, sym_var, 'numpy')
    func2 = lambda x: func(*[x[:,[i]] for i in range(len(orig_vars))])
    sym_grad = batch_jacobian(func2, x)

    # column indices of the variables that appear in symmetry_var
    idx = []
    sym_symbols = list(sym_var.free_symbols)
    for sym_symbol in sym_symbols:
        for j in range(len(orig_vars)):
            if sym_symbol == orig_vars[j]:
                idx.append(j)

    input_grad_part = input_grad[:,idx]
    sym_grad_part = sym_grad[:,idx]

    # per-sample |cosine| between the two gradients, restricted to those variables
    cossim = torch.abs(torch.sum(input_grad_part * sym_grad_part, dim=1)/(torch.norm(input_grad_part, dim=1)*torch.norm(sym_grad_part, dim=1)))

    ratio = (cossim > 0.9).float().mean().item()
    print(f'{100*ratio:.1f}% data have more than 0.9 cosine similarity')
    if ratio > 0.9:
        print('suggesting symmetry')
    else:
        print('not suggesting symmetry')

    return cossim
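test_symmetry_var relies on a chain-rule observation. If the model depends on b and c only through a combination s(b, c), then its gradient with respect to (b, c) is a scalar multiple of $\nabla s$, so the two gradients are parallel. The sketch below re-implements the check with plain `torch.autograd.grad` instead of pykan's `batch_jacobian`. Summing the outputs before differentiating still gives per-sample gradients, because each row of the output depends only on its own row of input. The function name `grad_cosine_symmetry` is illustrative, and the 0.9 threshold is taken from the code above.

```python
import torch

def grad_cosine_symmetry(f, g, x, idx, th=0.9):
    """Illustrative sketch: |cos| between grad f and grad g on columns idx."""
    x = x.clone().requires_grad_(True)
    # rows are independent, so the gradient of the sum equals per-row gradients
    grad_f = torch.autograd.grad(f(x).sum(), x)[0][:, idx]
    grad_g = torch.autograd.grad(g(x).sum(), x)[0][:, idx]
    cos = torch.abs((grad_f * grad_g).sum(dim=1)) / (grad_f.norm(dim=1) * grad_g.norm(dim=1))
    ratio = (cos > th).float().mean().item()   # fraction of samples above threshold
    return cos, ratio
```

For the docstring model `x0 * (x1 + x2)`, the candidate `b + c` gives a cosine of 1 on every sample. The candidate `b * c` has gradient (c, b), which matches only for inputs near the diagonals, so the ratio falls well below 0.9.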
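IV. Summary and Reflections
By combining a mathematical theorem with deep learning, KANs offer a new approach to scientific computing and interpretable AI. They still need breakthroughs before they are practical in high-dimensional applications, but their potential for modeling complex low-dimensional functions deserves attention. With better computational efficiency and broader theoretical foundations, they may become an important complement to MLPs.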
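1. KAN Network Architecture
- Key design:
  - Learnable activation functions: the "weight" on each connection is replaced by a learnable univariate function (e.g., a spline or polynomial), instead of a fixed activation function such as ReLU.
  - Layered structure: input and hidden layers, and hidden and output layers, are all connected through univariate functions, and these layers are stacked.
  - Parameter efficiency: building on the representation theorem, KANs may reach approximation accuracy comparable to or better than MLPs with fewer parameters; the paper reports this on small-scale function-fitting tasks.
- Example structure: input layer → hidden layer: each input node $x_p$ is connected to hidden node $q$ through a univariate function $\phi_{q,p}$. Hidden layer → output layer: the hidden nodes are combined through another set of univariate functions $\Phi_q$ to produce the output, i.e. $f(\mathbf{x}) = \sum_{q} \Phi_q\left(\sum_{p} \phi_{q,p}(x_p)\right)$.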
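As a concrete illustration of this layered structure, here is a minimal sketch of one KAN layer in PyTorch. Each edge $(p \to q)$ owns its own learnable univariate function $\phi_{q,p}$, and output node $q$ simply sums $\phi_{q,p}(x_p)$ over all inputs $x_p$. For brevity, each univariate function is a learnable combination of fixed Gaussian bumps (radial basis functions) rather than the B-splines used in the paper. The class name `ToyKANLayer` and this basis choice are assumptions, not pykan's API.

```python
import torch
import torch.nn as nn

class ToyKANLayer(nn.Module):
    """Hypothetical minimal KAN layer: y_q = sum_p phi_{q,p}(x_p).

    Each phi_{q,p} is a learnable combination of fixed Gaussian RBFs
    (a stand-in for the B-splines used in the KAN paper).
    """
    def __init__(self, in_dim, out_dim, num_basis=8, x_min=-2.0, x_max=2.0):
        super().__init__()
        # RBF centers shared by all edges (not trained)
        self.register_buffer("centers", torch.linspace(x_min, x_max, num_basis))
        self.width = (x_max - x_min) / (num_basis - 1)
        # one coefficient vector per edge: (out_dim, in_dim, num_basis)
        self.coef = nn.Parameter(0.1 * torch.randn(out_dim, in_dim, num_basis))

    def forward(self, x):                      # x: (batch, in_dim)
        # basis[b, p, k] = exp(-((x_p - c_k) / width)^2)
        basis = torch.exp(-((x.unsqueeze(-1) - self.centers) / self.width) ** 2)
        # phi[b, q, p] = sum_k coef[q, p, k] * basis[b, p, k]
        phi = torch.einsum("bpk,qpk->bqp", basis, self.coef)
        return phi.sum(dim=-1)                 # sum over inputs p: (batch, out_dim)
```

Stacking layers, e.g. `nn.Sequential(ToyKANLayer(2, 5), ToyKANLayer(5, 1))`, gives a KAN of shape [2, 5, 1]. Because a single layer is a pure sum of univariate terms, interactions between inputs can only arise across layers.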
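2. Advantages and Features
- High approximation efficiency: building on the representation theorem, KANs can in principle approximate complex functions with fewer parameters, and they perform well on low-dimensional scientific computing tasks such as solving differential equations.
- Interpretability: the univariate functions can be visualized, which makes it easy to analyze how each input relates to the output; the network structure maps directly onto the function decomposition, so its logic is transparent.
- Flexible function learning: the activation functions adapt during training (e.g., learning smooth or non-smooth shapes), and symbolic formulas can be extracted (e.g., recovering physical laws from data).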
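Symbolic extraction generally works by matching each learned univariate function against a library of candidate formulas. A simplified version of the idea: sample the learned curve, fit $y \approx a\,f(x) + d$ for each candidate $f$ by least squares, and keep the candidate with the highest $R^2$. The candidate set and the helper name `best_symbolic_fit` are assumptions. The library's own fitting is more elaborate, since it also fits an affine transform inside $f$, and that step is omitted here.

```python
import numpy as np

# Hypothetical, deliberately small candidate library
CANDIDATES = {
    "x": lambda x: x,
    "x^2": lambda x: x ** 2,
    "sin": np.sin,
    "exp": np.exp,
}

def best_symbolic_fit(x, y):
    """Fit y ~ a * f(x) + d for each candidate f; return the best one by R^2."""
    results = {}
    for name, f in CANDIDATES.items():
        A = np.column_stack([f(x), np.ones_like(x)])    # design matrix [f(x), 1]
        (a, d), *_ = np.linalg.lstsq(A, y, rcond=None)  # least-squares a, d
        resid = y - (a * f(x) + d)
        r2 = 1 - np.sum(resid ** 2) / np.sum((y - y.mean()) ** 2)
        results[name] = (r2, a, d)
    best = max(results, key=lambda k: results[k][0])
    return best, results[best]
```

For samples of `3 * np.sin(x) + 1` on [-3, 3], the sketch selects `"sin"` with `a ≈ 3` and `d ≈ 1`.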
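3. Challenges and Limitations
- Computational cost: learning the univariate functions (e.g., their spline parameters) can increase training time and memory use, and optimizing many learnable continuous functions places higher demands on hardware and algorithms.
- Generalization: performance on high-dimensional data (e.g., images, text) has not been thoroughly validated and may fall short of conventional MLPs.
- Training difficulty: new optimization strategies are needed to keep the univariate functions from overfitting or underfitting.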
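The cost point can be made concrete by counting parameters. With a B-spline of order $k$ on $G$ grid intervals, each KAN edge carries $G + k$ coefficients, whereas an MLP edge carries a single weight. The sketch counts only spline coefficients for KANs, and weights plus biases for MLPs. Practical implementations add a few extra per-edge parameters (e.g., scale factors), which are ignored here. Both function names are hypothetical.

```python
def mlp_params(widths):
    """Weights + biases of a fully connected network with the given layer widths."""
    return sum(n_in * n_out + n_out for n_in, n_out in zip(widths[:-1], widths[1:]))

def kan_spline_params(widths, grid=5, k=3):
    """Spline coefficients of a KAN: (grid + k) per edge, edges = n_in * n_out."""
    return sum(n_in * n_out * (grid + k) for n_in, n_out in zip(widths[:-1], widths[1:]))
```

For widths [2, 5, 1] with G = 5 and k = 3, this gives 120 spline coefficients versus 21 MLP parameters. The paper's counterargument is that KANs often need much smaller widths to reach the same accuracy.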
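4. Application Scenarios
- Scientific computing: solving differential equations, physical modeling, chemical simulation, and other tasks that need high-precision function approximation.
- Domains that require interpretability: medical diagnosis, financial risk control, and other settings where the relationship between inputs and outputs must be explicit.
- Symbolic regression: automatically discovering mathematical formulas (such as physical laws) from data.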
5. Comparison with Traditional MLPs
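6. Research Progress
- Recent paper: in 2024, a team from MIT and collaborating institutions proposed the KAN architecture in the paper "KAN: Kolmogorov-Arnold Networks", validating its efficiency and interpretability on low-dimensional tasks.
- Open-source implementations: early implementations exist in PyTorch and other frameworks, including the authors' PyTorch library pykan, which provides the `kan.hypothesis` module discussed in this article.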
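[Author's Statement]
The paper content and views shared in this article all come from the original paper "KAN: Kolmogorov-Arnold Networks" and are intended to introduce and discuss the innovations and application value of that research. The author respects and follows academic norms and strives to keep the content accurate and objective. For any questions or further information, please refer to the original paper or contact its authors.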
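[Follow Us]
If you are interested in neural networks, swarm intelligence algorithms, and artificial intelligence, follow 【灵犀拾荒者】 for more cutting-edge technical articles, hands-on case studies, and technical insights!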