I. Introduction

        Kolmogorov-Arnold Networks (KANs) are a new neural network architecture based on the Kolmogorov-Arnold representation theorem, which states that any multivariate continuous function can be expressed as a finite composition of continuous univariate functions combined through addition. Unlike a traditional multilayer perceptron (MLP), a KAN places learnable activation functions on its edges within a structured network design, showing promise in both approximation efficiency and interpretability.


II. Overview of the Technique and Principles

        1. The Kolmogorov-Arnold Representation Theorem

         The Kolmogorov-Arnold representation theorem states that if f is a multivariate continuous function on a bounded domain, then it can be written as a finite composition of continuous functions of a single variable together with the binary operation of addition. More specifically, for a smooth f:[0,1]^{n}\rightarrow \mathbb{R},

f \left( x \right)=f \left( x_{1}, \cdots,x_{n} \right)= \sum_{q=1}^{2n+1} \Phi_{q} \left( \sum_{p=1}^{n} \phi_{q,p} \left( x_{p} \right) \right)

        where \phi_{q,p}:[0,1]\to\mathbb{R} and \Phi_{q}:\mathbb{R}\to\mathbb{R}. In a sense, this shows that the only truly multivariate function is addition, since every other function can be written using univariate functions and sums. However, this two-layer, width-(2n+1) Kolmogorov-Arnold representation may not be smooth because of its limited expressive power. We augment its expressive power by generalizing it to arbitrary depths and widths.
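
        As a concrete illustration (added here, not in the original text), multiplication itself can be written in this univariate-plus-addition form:

x_{1}x_{2}=\frac{\left(x_{1}+x_{2}\right)^{2}-\left(x_{1}-x_{2}\right)^{2}}{4}

        i.e. taking \Phi_{1}(u)=u^{2}/4, \Phi_{2}(u)=-u^{2}/4, \phi_{1,1}(x)=\phi_{1,2}(x)=\phi_{2,1}(x)=x and \phi_{2,2}(x)=-x reproduces the form of the theorem for f(x_{1},x_{2})=x_{1}x_{2}.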

        2. Kolmogorov-Arnold Networks (KANs)

        The Kolmogorov-Arnold representation can be written in matrix form:

f(\mathbf{x})=\mathbf{\Phi}_{\mathrm{out}}\circ\mathbf{\Phi}_{\mathrm{in}}\circ\mathbf{x}

where

\mathbf{\Phi}_{\mathrm{in}}=\begin{pmatrix}\phi_{1,1}(\cdot)&\cdots&\phi_{1,n}(\cdot)\\ \vdots&&\vdots\\ \phi_{2n+1,1}(\cdot)&\cdots&\phi_{2n+1,n}(\cdot)\end{pmatrix},\qquad \mathbf{\Phi}_{\mathrm{out}}=\left(\Phi_{1}(\cdot)\quad\cdots\quad\Phi_{2n+1}(\cdot)\right)

        We note that both \mathbf{\Phi}_{\mathrm{in}} and \mathbf{\Phi}_{\mathrm{out}} are special cases of the following function matrix \mathbf{\Phi} (with n_{\mathrm{in}} inputs and n_{\mathrm{out}} outputs), which we call a Kolmogorov-Arnold layer:

\mathbf{\Phi}=\begin{pmatrix}\phi_{1,1}(\cdot)&\cdots&\phi_{1,n_{\mathrm{in}}}(\cdot)\\ \vdots&&\vdots\\ \phi_{n_{\mathrm{out}},1}(\cdot)&\cdots&\phi_{n_{\mathrm{out}},n_{\mathrm{in}}}(\cdot)\end{pmatrix}

        For \mathbf{\Phi}_{\mathrm{in}} we have n_{\mathrm{in}}=n and n_{\mathrm{out}}=2n+1; for \mathbf{\Phi}_{\mathrm{out}}, n_{\mathrm{in}}=2n+1 and n_{\mathrm{out}}=1.

        Having defined the layer, we can construct a Kolmogorov-Arnold network simply by stacking layers! Suppose we have L layers, and the l-th layer \Phi_{l} has shape \left(n_{l+1},n_{l}\right). Then the whole network is

\mathbf{KAN(x)}=\mathbf{\Phi_{L-1}}\circ\cdots\circ\mathbf{\Phi_{1}}\circ \mathbf{\Phi_{0}}\circ\mathbf{x}

        In contrast, a multilayer perceptron interleaves linear layers \mathbf{W}_{l} with fixed nonlinearities \sigma:

\text{MLP}(\mathbf{x})=\mathbf{W}_{\textit{L-1}}\circ\sigma\circ\cdots\circ \mathbf{W}_{1}\circ\sigma\circ\mathbf{W}_{0}\circ\mathbf{x}

        KANs are easy to visualize: (1) a KAN is simply a stack of KAN layers, and (2) each KAN layer can be visualized as a fully connected layer with a learnable 1D function placed on each edge.
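
        The sketch below (an illustration of the idea, not the official pykan implementation) shows a toy KAN layer in PyTorch: every edge carries its own learnable 1D function, parameterized here by a simple polynomial basis instead of the B-splines used in the paper, and a KAN is just a stack of such layers.

import torch
import torch.nn as nn

class NaiveKANLayer(nn.Module):
    """Toy Kolmogorov-Arnold layer: one learnable 1D function per edge.

    Each edge function phi_{j,i} is a cubic polynomial with learnable
    coefficients; pykan uses B-splines plus a base function instead.
    """
    def __init__(self, n_in, n_out, degree=3):
        super().__init__()
        # coeff[j, i, k] = k-th polynomial coefficient of phi_{j,i}
        self.coeff = nn.Parameter(0.1 * torch.randn(n_out, n_in, degree + 1))
        self.degree = degree

    def forward(self, x):                              # x: (batch, n_in)
        powers = torch.stack([x ** k for k in range(self.degree + 1)], dim=-1)
        # phi[b, j, i] = sum_k coeff[j, i, k] * x[b, i]^k
        phi = torch.einsum('bik,jik->bji', powers, self.coeff)
        return phi.sum(dim=-1)                         # sum over incoming edges

# KAN(x) = Phi_1 o Phi_0 (x): stack two layers, e.g. shapes [3, 5, 1]
kan = nn.Sequential(NaiveKANLayer(3, 5), NaiveKANLayer(5, 1))
y = kan(torch.randn(100, 3))                           # -> (100, 1)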


III. Code Walkthrough

        The code revolves around detecting interactions, separability, and symmetry among the model's input variables, and then building a hierarchy of variable groupings (a tree diagram) to explain the model's internal structure and how it fits the data. Each function relies on numerical computation (e.g., gradients and Hessians) together with clustering or statistical methods to analyze the model in depth.

        A. Function-by-Function Walkthrough

        Separability Detection Functions

        (1) detect_separability
def detect_separability(model, x, mode='add', score_th=1e-2, res_th=1e-2, n_clusters=None, bias=0., verbose=False):
    '''
        detect function separability
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            mode : str
                mode = 'add' or mode = 'mul'
            score_th : float
                threshold of score
            res_th : float
                threshold of residue
            n_clusters : None or int
                the number of clusters
            bias : float
                bias (for multiplicative separability)
            verbose : bool

        Returns:
        --------
            results (dictionary)
            
        Example1
        --------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,3))
        >>> detect_separability(model, x, mode='add')
        
        Example2
        --------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,3))
        >>> detect_separability(model, x, mode='mul')
    '''
    results = {}
    
    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x)
        
    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None,:] * std[:,None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
    results['hessian'] = score_mat

    dist_hard = (score_mat < score_th).float()
    
    if isinstance(n_clusters, int):
        n_cluster_try = [n_clusters, n_clusters]
    elif isinstance(n_clusters, list):
        n_cluster_try = n_clusters
    else:
        n_cluster_try = [1,x.shape[1]]
        
    n_cluster_try = list(range(n_cluster_try[0], n_cluster_try[1]+1))
    
    for n_cluster in n_cluster_try:
        
        clustering = AgglomerativeClustering(
          metric='precomputed',
          n_clusters=n_cluster,
          linkage='complete',
        ).fit(dist_hard)

        labels = clustering.labels_

        groups = [list(np.where(labels == i)[0]) for i in range(n_cluster)]
        blocks = [torch.sum(score_mat[groups[i]][:,groups[i]]) for i in range(n_cluster)]
        block_sum = torch.sum(torch.stack(blocks))
        total_sum = torch.sum(score_mat)
        residual_sum = total_sum - block_sum
        residual_ratio = residual_sum / total_sum
        
        if verbose == True:
            print(f'n_group={n_cluster}, residual_ratio={residual_ratio}')
        
        if residual_ratio < res_th:
            results['n_groups'] = n_cluster
            results['labels'] = list(labels)
            results['groups'] = groups
            
    if results['n_groups'] > 1:
        print(f'{mode} separability detected')
    else:
        print(f'{mode} separability not detected')
            
    return results
  • Purpose: detect whether a function is additively separable (mode='add') or multiplicatively separable (mode='mul').
  • How it works: compute the model's Hessian (second derivatives) and analyze a normalized version of it to decide whether the function splits into independent parts over groups of variables. For additive separability the Hessian is inspected directly; for multiplicative separability a log transform is applied first and separability is checked in log space.
  • Output: a dictionary containing the Hessian scores, the number of groups (if separable), and the group labels.
  • Example: for f(x)=x_{1}^{2}+\exp(x_{2}+x_{3}), additive separability can be detected; a minimal illustration of the underlying criterion follows below.
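
        The snippet below is a minimal, single-sample illustration of that criterion (a toy example independent of pykan): for an additively separable function, the mixed second derivatives coupling the two groups vanish, which is exactly what the normalized Hessian score matrix measures in batched form inside detect_separability.

import torch

# f(x) = x1^2 + exp(x2 + x3) is additive across the groups {x1} and {x2, x3}:
# Hessian entries coupling x1 with x2 or x3 are zero, while x2 and x3 interact.
f = lambda v: v[0] ** 2 + torch.exp(v[1] + v[2])   # scalar function of a 3-vector
v = torch.randn(3)
H = torch.autograd.functional.hessian(f, v)        # 3x3 Hessian at one point
print(H[0, 1].item(), H[0, 2].item())              # ~0, ~0
print(H[1, 2].item())                              # nonzero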
        (2) test_separability
def test_separability(model, x, groups, mode='add', threshold=1e-2, bias=0):
    '''
        test function separability
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            mode : str
                mode = 'add' or mode = 'mul'
            score_th : float
                threshold of score
            res_th : float
                threshold of residue
            bias : float
                bias (for multiplicative separability)
            verbose : bool

        Returns:
        --------
            bool
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,3))
        >>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True
        >>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False
    '''
    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x)

    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None,:] * std[:,None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
    
    sep_bool = True
    
    # internal test
    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i+1, n_groups):
            sep_bool *= torch.max(score_mat[groups[i]][:,groups[j]]) < threshold
            
    # external test
    group_id = [x for xs in groups for x in xs]
    nongroup_id = list(set(range(x.shape[1])) - set(group_id))
    if len(nongroup_id) > 0 and len(group_id) > 0:
        sep_bool *= torch.max(score_mat[group_id][:,nongroup_id]) < threshold 

    return sep_bool
  • Purpose: given groups of variables, test whether the function is separable across them.
  • How it works: compute the Hessian and check whether the cross-group blocks fall below a threshold, which indicates separability; the additive and multiplicative modes are handled separately (see the sketch below for the multiplicative case).
  • Output: a boolean indicating whether the function is separable.
  • Example: test the multiplicative separability of f(x)=x_{1}^{2}\,(x_{2}+x_{3}) under the grouping [[0],[1,2]].
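
        Why mode='mul' can reuse the additive test (an illustrative sketch): a multiplicatively separable function becomes additively separable after taking the log of its absolute value, so the same Hessian criterion applies in log space. This mirrors how the code composes torch.log, torch.abs, and an optional bias with the model before calling batch_hessian.

import torch

# f(x) = x1^2 * (x2 + x3); log|f| = 2*log|x1| + log|x2 + x3| is additive
# across {x1} and {x2, x3}, so its mixed Hessian entries vanish.
f = lambda v: v[0] ** 2 * (v[1] + v[2])
logf = lambda v: torch.log(torch.abs(f(v)))
v = torch.randn(3)
H = torch.autograd.functional.hessian(logf, v)
print(H[0, 1].item(), H[0, 2].item())   # ~0: x1 decouples from (x2, x3) in log space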
        (3) test_general_separability
def test_general_separability(model, x, groups, threshold=1e-2):
    '''
        test function separability
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            mode : str
                mode = 'add' or mode = 'mul'
            score_th : float
                threshold of score
            res_th : float
                threshold of residue
            bias : float
                bias (for multiplicative separability)
            verbose : bool

        Returns:
        --------
            bool
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2
        >>> x = torch.normal(0,1,size=(100,3))
        >>> print(test_general_separability(model, x, [[1],[0,2]])) # False
        >>> print(test_general_separability(model, x, [[0],[1,2]])) # True
    '''
    grad = batch_jacobian(model, x)

    gensep_bool = True

    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i+1,n_groups):
            group_A = groups[i]
            group_B = groups[j]
            for member_A in group_A:
                for member_B in group_B:
                    def func(x):
                        grad = batch_jacobian(model, x, create_graph=True)
                        return grad[:,[member_B]]/grad[:,[member_A]]
                    # test if func is multiplicative separable
                    gensep_bool *= test_separability(func, x, groups, mode='mul', threshold=threshold)
    return gensep_bool
  • Purpose: test a more general notion of separability, not restricted to pure addition or multiplication.
  • How it works: compute the Jacobian (first derivatives) and check whether cross-group gradient ratios are multiplicatively separable (see the derivation below).
  • Output: a boolean indicating general separability.
  • Example: test the separability of f(x)=x_{1}^{2}\,(x_{2}^{2}+x_{3}^{2})^{2} under the grouping [[0],[1,2]].
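
        To see what the gradient-ratio trick is testing, here is a short derivation under the assumption (my reading of the code, not stated in the original text) that generalized separability means f(\mathbf{x})=F\left(g(\mathbf{x}_{A})+h(\mathbf{x}_{B})\right) for some outer function F:

\frac{\partial f/\partial x_{B}}{\partial f/\partial x_{A}}=\frac{F'\left(g+h\right)\,\partial h/\partial x_{B}}{F'\left(g+h\right)\,\partial g/\partial x_{A}}=\frac{\partial h/\partial x_{B}}{\partial g/\partial x_{A}}

        The outer derivative F' cancels, leaving a product of a function of \mathbf{x}_{B} and a function of \mathbf{x}_{A}; that is why the inner func in the code is checked with test_separability(..., mode='mul'). The docstring example fits this form, since x_{1}^{2}(x_{2}^{2}+x_{3}^{2})^{2}=\exp\left(2\ln|x_{1}|+2\ln(x_{2}^{2}+x_{3}^{2})\right).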

        Symmetry Testing Functions

        (4) test_symmetry
def test_symmetry(model, x, group, dependence_th=1e-3):
    '''
        detect function separability
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            group : a list of indices
            dependence_th : float
                threshold of dependence

        Returns:
        --------
            bool
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,3))
        >>> print(test_symmetry(model, x, [1,2])) # True
        >>> print(test_symmetry(model, x, [0,2])) # False
    '''
    if len(group) == x.shape[1] or len(group) == 0:
        return True
    
    dependence = get_dependence(model, x, group)
    max_dependence = torch.max(dependence)
    return max_dependence < dependence_th
  • Purpose: test whether the function is symmetric with respect to a given group of variables, i.e., depends on that group only through some combined quantity.
  • How it works: compute the dependence between the group and the remaining variables, and check that the maximum dependence is below the threshold (a short derivation follows below).
  • Output: a boolean indicating symmetry.
  • Example: test the symmetry of f(x)=x_{1}^{2}\,(x_{2}+x_{3}) with respect to the group [1,2].
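
        A brief note on why this works (a sketch, not from the original text): if the function depends on the group \mathbf{x}_{A} only through a single scalar,

f(\mathbf{x})=g\left(h\left(\mathbf{x}_{A}\right),\mathbf{x}_{B}\right)\quad\Rightarrow\quad\nabla_{A}f=\frac{\partial g}{\partial h}\,\nabla h\left(\mathbf{x}_{A}\right)

        then the normalized in-group gradient \nabla_{A}f/\lVert\nabla_{A}f\rVert equals \pm\nabla h/\lVert\nabla h\rVert and depends only on \mathbf{x}_{A}. get_dependence (a helper described below) differentiates this normalized gradient with respect to the out-of-group variables; if those derivatives stay below dependence_th everywhere, the group is declared symmetric.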
        (5) test_symmetry_var
def test_symmetry_var(model, x, input_vars, symmetry_var):
    '''
        test symmetry
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            input_vars : list of sympy symbols
            symmetry_var : sympy expression

        Returns:
        --------
            cosine similarity
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> from sympy import *
        >>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,8))
        >>> input_vars = a, b, c = symbols('a b c')
        >>> symmetry_var = b + c
        >>> test_symmetry_var(model, x, input_vars, symmetry_var);
        >>> symmetry_var = b * c
        >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    '''
    orig_vars = input_vars
    sym_var = symmetry_var
    
    # gradients wrt to input (model)
    input_grad = batch_jacobian(model, x)

    # gradients wrt to input (symmetry var)
    func = lambdify(orig_vars, sym_var,'numpy') # returns a numpy-ready function

    func2 = lambda x: func(*[x[:,[i]] for i in range(len(orig_vars))])
    sym_grad = batch_jacobian(func2, x)
    
    # get id
    idx = []
    sym_symbols = list(sym_var.free_symbols)
    for sym_symbol in sym_symbols:
        for j in range(len(orig_vars)):
            if sym_symbol == orig_vars[j]:
                idx.append(j)
                
    input_grad_part = input_grad[:,idx]
    sym_grad_part = sym_grad[:,idx]
    
    cossim = torch.abs(torch.sum(input_grad_part * sym_grad_part, dim=1)/(torch.norm(input_grad_part, dim=1)*torch.norm(sym_grad_part, dim=1)))
    
    ratio = torch.sum(cossim > 0.9)/len(cossim)
    
    print(f'{100*ratio}% data have more than 0.9 cosine similarity')
    if ratio > 0.9:
        print('suggesting symmetry')
    else:
        print('not suggesting symmetry')
        
    return cossim
  • Purpose: test whether the function is symmetric with respect to a symbolic expression of the inputs.
  • How it works: compute the gradients of the model and of the symbolic expression, then measure their cosine similarity to decide whether symmetry holds (a worked instance follows below).
  • Output: prints the percentage of data points with high cosine similarity and whether symmetry is suggested; returns the cosine similarities.
  • Example: test the symmetry of f(x)=x_{1}\,(x_{2}+x_{3}) with respect to x_{2}+x_{3}.
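
        A worked instance of the cosine check for the docstring example (my own calculation, for illustration): with f(x)=x_{1}(x_{2}+x_{3}) and symmetry_var = b + c, the gradients restricted to the free symbols (indices 1 and 2) are

\nabla_{(x_{2},x_{3})}f=\left(x_{1},x_{1}\right),\qquad\nabla_{(b,c)}\left(b+c\right)=\left(1,1\right)

        These are parallel at every sample, so the cosine similarity is exactly 1 and symmetry is suggested. For symmetry_var = b*c the second gradient becomes (x_{3},x_{2}), which is generally not parallel to (x_{1},x_{1}), so the similarity drops and symmetry is not suggested.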

        Hierarchy Analysis Functions

        (6) get_molecule
def get_molecule(model, x, sym_th=1e-3, verbose=True):
    '''
        how variables are combined hierarchically
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            sym_th : float
                threshold of symmetry
            verbose : bool

        Returns:
        --------
            list
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
        >>> x = torch.normal(0,1,size=(100,8))
        >>> get_molecule(model, x, verbose=False)
        [[[0], [1], [2], [3], [4], [5], [6], [7]],
         [[0, 1], [2, 3], [4, 5], [6, 7]],
         [[0, 1, 2, 3], [4, 5, 6, 7]],
         [[0, 1, 2, 3, 4, 5, 6, 7]]]
    '''
    n = x.shape[1]
    atoms = [[i] for i in range(n)]
    molecules = []
    moleculess = [copy.deepcopy(atoms)]
    already_full = False
    n_layer = 0
    last_n_molecule = n
    
    while True:
        

        pointer = 0
        current_molecule = []
        remove_atoms = []
        n_atom = 0

        while len(atoms) > 0:
            
            # assemble molecule
            atom = atoms[pointer]
            if verbose:
                print(current_molecule)
                print(atom)

            if len(current_molecule) == 0:
                full = False
                current_molecule += atom
                remove_atoms.append(atom)
                n_atom += 1
            else:
                # try assemble the atom to the molecule
                if len(current_molecule+atom) == x.shape[1] and already_full == False and n_atom > 1 and n_layer > 0:
                    full = True
                    already_full = True
                else:
                    full = False
                    if test_symmetry(model, x, current_molecule+atom, dependence_th=sym_th):
                        current_molecule += atom
                        remove_atoms.append(atom)
                        n_atom += 1

            pointer += 1
            
            if pointer == len(atoms) or full:
                molecules.append(current_molecule)
                if full:
                    molecules.append(atom)
                    remove_atoms.append(atom)
                # remove molecules from atoms
                for atom in remove_atoms:
                    atoms.remove(atom)
                current_molecule = []
                remove_atoms = []
                pointer = 0
                
        # if not making progress, terminate
        if len(molecules) == last_n_molecule:
            def flatten(xss):
                return [x for xs in xss for x in xs]
            moleculess.append([flatten(molecules)])
            break
        else:
            moleculess.append(copy.deepcopy(molecules))
            
        last_n_molecule = len(molecules)

        if len(molecules) == 1:
            break

        atoms = molecules
        molecules = []
        
        n_layer += 1
        
        #print(n_layer, atoms)
        
        
    # sort
    depth = len(moleculess) - 1

    for l in list(range(depth,0,-1)):

        molecules_sorted = []
        molecules_l = moleculess[l]
        molecules_lm1 = moleculess[l-1]


        for molecule_l in molecules_l:
            start = 0
            for i in range(1,len(molecule_l)+1):
                if molecule_l[start:i] in molecules_lm1:

                    molecules_sorted.append(molecule_l[start:i])
                    start = i

        moleculess[l-1] = molecules_sorted

    return moleculess
  • Purpose: identify hierarchical groupings of variables ("molecules") based on symmetry.
  • How it works: start from individual variables ("atoms") and iteratively test symmetry to assemble larger and larger groups ("molecules"), one level per pass.
  • Output: a list of lists, where each sub-list is one level of the hierarchy.
  • Example: for model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2, it returns [[[0],[1],...], [[0,1],[2,3],...], ...] (see the snippet below for how to read the levels).
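
        A hypothetical quick check on reading the returned hierarchy (assuming pykan is installed; the model and its expected output are the ones from the docstring above): moleculess[0] holds the leaves and moleculess[-1] the root.

import torch
from kan.hypothesis import get_molecule

model = lambda x: ((x[:,[0]]**2 + x[:,[1]]**2)**2 + (x[:,[2]]**2 + x[:,[3]]**2)**2)**2 \
                + ((x[:,[4]]**2 + x[:,[5]]**2)**2 + (x[:,[6]]**2 + x[:,[7]]**2)**2)**2
x = torch.normal(0, 1, size=(100, 8))
moleculess = get_molecule(model, x, verbose=False)
print(moleculess[0])    # leaves: one singleton "atom" per input variable
print(moleculess[-1])   # root: all variables merged into a single group
print(len(moleculess))  # number of hierarchy levels (4 for this example)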
        (7) get_tree_node
def get_tree_node(model, x, moleculess, sep_th=1e-2, skip_test=True):
    '''
        get tree nodes
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            sep_th : float
                threshold of separability
            skip_test : bool
                if True, don't test the property of each module (to save time)

        Returns:
        --------
            arities : list of numbers
            properties : list of strings
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
        >>> x = torch.normal(0,1,size=(100,8))
        >>> moleculess = get_molecule(model, x, verbose=False)
        >>> get_tree_node(model, x, moleculess, skip_test=False)
    '''
    arities = []
    properties = []
    
    depth = len(moleculess) - 1

    for l in range(depth):
        molecules_l = copy.deepcopy(moleculess[l])
        molecules_lp1 = copy.deepcopy(moleculess[l+1])
        arity_l = []
        property_l = []

        for molecule in molecules_lp1:
            start = 0
            arity = 0
            groups = []
            for i in range(1,len(molecule)+1):
                if molecule[start:i] in molecules_l:
                    groups.append(molecule[start:i])
                    start = i
                    arity += 1
            arity_l.append(arity)

            if arity == 1:
                property = 'Id'
            else:
                property = ''
                # test property
                if skip_test:
                    gensep_bool = False
                else:
                    gensep_bool = test_general_separability(model, x, groups, threshold=sep_th)
                    
                if gensep_bool:
                    property = 'GS'
                if l == depth - 1:
                    if skip_test:
                        add_bool = False
                        mul_bool = False
                    else:
                        add_bool = test_separability(model, x, groups, mode='add', threshold=sep_th)
                        mul_bool = test_separability(model, x, groups, mode='mul', threshold=sep_th)
                    if add_bool:
                        property = 'Add'
                    if mul_bool:
                        property = 'Mul'


            property_l.append(property)


        arities.append(arity_l)
        properties.append(property_l)
        
    return arities, properties
  • Purpose: build the tree structure from the molecules and assign a property to each node ('Add', 'Mul', 'GS', or 'Id').
  • How it works: analyze how each molecule decomposes into molecules of the previous level to determine each node's number of inputs (arity) and its operation type.
  • Output: two lists, arities (the number of inputs of each node) and properties (the operation type of each node).
  • Example: based on the molecules above, determine whether each node is additive or multiplicative.
        (8) plot_tree
def plot_tree(model, x, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
    '''
        get tree graph
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            in_var : list of symbols
                input variables
            style : str
                'tree' or 'box'
            sym_th : float
                threshold of symmetry
            sep_th : float
                threshold of separability
            skip_sep_test : bool
                if True, don't test the property of each module (to save time)
            verbose : bool

        Returns:
        --------
            a tree graph
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
        >>> x = torch.normal(0,1,size=(100,8))
        >>> plot_tree(model, x)
    '''
    moleculess = get_molecule(model, x, sym_th=sym_th, verbose=verbose)
    arities, properties = get_tree_node(model, x, moleculess, sep_th=sep_th, skip_test=skip_sep_test)

    n = x.shape[1]
    var = None

    in_vars = []

    if in_var == None:
        for ii in range(1, n + 1):
            exec(f"x{ii} = sympy.Symbol('x_{ii}')")
            exec(f"in_vars.append(x{ii})")
    elif type(var[0]) == Symbol:
        in_vars = var
    else:
        in_vars = [sympy.symbols(var_) for var_ in var]
        

    def flatten(xss):
        return [x for xs in xss for x in xs]

    def myrectangle(center_x, center_y, width_x, width_y):
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y + width_y/2, center_y + width_y/2], color='k') # up
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y - width_y/2], color='k') # down
        plt.plot([center_x - width_x/2, center_x - width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # left
        plt.plot([center_x + width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # left

    depth = len(moleculess)

    delta = 1/n
    a = 0.3
    b = 0.15
    y0 = 0.5


    # draw rectangles
    for l in range(depth-1):
        molecules = moleculess[l+1]
        n_molecule = len(molecules)

        centers = []

        acc_arity = 0

        for i in range(n_molecule):
            start_id = len(flatten(molecules[:i]))
            end_id = len(flatten(molecules[:i+1]))

            center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2
            center_y = (l+1/2)*y0
            width_x = (end_id - start_id - 1 + 2*a)*delta
            width_y = 2*b

            # add text (numbers) on rectangles
            if style == 'box':
                myrectangle(center_x, center_y, width_x, width_y)
                plt.text(center_x, center_y, properties[l][i], fontsize=15, horizontalalignment='center',
                verticalalignment='center')
            elif style == 'tree':
                # if 'GS', no rectangle, n=arity tilted lines
                # if 'Id', no rectangle, n=arity vertical lines
                # if 'Add' or 'Mul'. rectangle, "+" or "x"
                # if '', rectangle
                property = properties[l][i]
                if property == 'GS' or property == 'Add' or property == 'Mul':
                    color = 'blue'
                    arity = arities[l][i]
                    for j in range(arity):

                        if l == 0:
                            # x = (start_id + j) * delta + delta/2, center_x
                            # y = center_y - b, center_y + b
                            plt.plot([(start_id + j) * delta + delta/2, center_x], [center_y - b, center_y + b], color=color)
                        else:
                            # x = last_centers[acc_arity:acc_arity+arity], center_x
                            # y = center_y - b, center_y + b
                            plt.plot([last_centers[acc_arity+j], center_x], [center_y - b, center_y + b], color=color)

                    acc_arity += arity

                    if property == 'Add' or property == 'Mul':
                        if property == 'Add':
                            symbol = '+'
                        else:
                            symbol = '*'

                        plt.text(center_x, center_y + b, symbol, horizontalalignment='center',
                verticalalignment='center', color='red', fontsize=40)
                if property == 'Id':
                    plt.plot([center_x, center_x], [center_y-width_y/2, center_y+width_y/2], color='black')
                    
                if property == '':
                    myrectangle(center_x, center_y, width_x, width_y)



            # connections to the next layer
            plt.plot([center_x, center_x], [center_y+width_y/2, center_y+y0-width_y/2], color='k')
            centers.append(center_x)
        last_centers = copy.deepcopy(centers)

    # connections from input variables to the first layer
    for i in range(n):
        x_ = (i + 1/2) * delta
        # connections to the next layer
        plt.plot([x_, x_], [0, y0/2-width_y/2], color='k')
        plt.text(x_, -0.05*(depth-1), f'${latex(in_vars[moleculess[0][i][0]])}$', fontsize=20, horizontalalignment='center')
    plt.xlim(0,1)
    #plt.ylim(0,1);
    plt.axis('off');
    plt.show()
  • Purpose: visualize the function's tree structure.
  • How it works: use the results of get_molecule and get_tree_node to draw a tree or box diagram, with nodes representing operations and edges representing variable groups (a usage sketch follows below).
  • Output: displays a graphical representation of the function's structure.
  • Example: plot the hierarchy diagram of the function above.
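
        A hypothetical usage sketch (assuming pykan and matplotlib are installed; the model is the one from the docstring), showing the 'box' style and explicit thresholds:

import torch
from kan.hypothesis import plot_tree

# Docstring example: a three-level nested sum of squares over 8 inputs.
model = lambda x: ((x[:,[0]]**2 + x[:,[1]]**2)**2 + (x[:,[2]]**2 + x[:,[3]]**2)**2)**2 \
                + ((x[:,[4]]**2 + x[:,[5]]**2)**2 + (x[:,[6]]**2 + x[:,[7]]**2)**2)**2
x = torch.normal(0, 1, size=(100, 8))

# style='tree' (default) draws lines with +/* symbols; style='box' draws labeled boxes.
plot_tree(model, x, style='box', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False)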

        Helper Functions

        (9) batch_grad_normgrad
def batch_grad_normgrad(model, x, group, create_graph=False):
    # x in shape (Batch, Length)
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))

    def jac(x):
        input_grad = batch_jacobian(model, x, create_graph=True)
        input_grad_A = input_grad[:,group_A]
        norm = torch.norm(input_grad_A, dim=1, keepdim=True) + 1e-6
        input_grad_A_normalized = input_grad_A/norm
        return input_grad_A_normalized
    
    def _jac_sum(x):
        return jac(x).sum(dim=0)
    
    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1,0,2)[:,:,group_B]
  • Computes the derivative, with respect to the out-of-group variables, of the normalized in-group gradient; used for dependence and symmetry analysis.
        (10) get_dependence
def get_dependence(model, x, group):
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))
    grad_normgrad = batch_grad_normgrad(model, x, group=group)
    std = torch.std(x, dim=0)
    dependence = grad_normgrad * std[None,group_A,None] * std[None,None,group_B]
    dependence = torch.median(torch.abs(dependence), dim=0)[0]
    return dependence
  • Measures the dependence between a group of variables and the remaining ones; used by the symmetry test.

        B. Complete Code

import numpy as np
import torch
from sklearn.linear_model import LinearRegression
from sympy.utilities.lambdify import lambdify
from sklearn.cluster import AgglomerativeClustering
from functools import reduce
from kan.utils import batch_jacobian, batch_hessian
import copy 
import matplotlib.pyplot as plt
import sympy 
from sympy.printing import latex


def detect_separability(model, x, mode='add', score_th=1e-2, res_th=1e-2, n_clusters=None, bias=0., verbose=False):
    '''
        detect function separability
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            mode : str
                mode = 'add' or mode = 'mul'
            score_th : float
                threshold of score
            res_th : float
                threshold of residue
            n_clusters : None or int
                the number of clusters
            bias : float
                bias (for multiplicative separability)
            verbose : bool

        Returns:
        --------
            results (dictionary)
            
        Example1
        --------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 + torch.exp(x[:,[1]]+x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,3))
        >>> detect_separability(model, x, mode='add')
        
        Example2
        --------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,3))
        >>> detect_separability(model, x, mode='mul')
    '''
    results = {}
    
    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x)
        
    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None,:] * std[:,None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
    results['hessian'] = score_mat

    dist_hard = (score_mat < score_th).float()
    
    if isinstance(n_clusters, int):
        n_cluster_try = [n_clusters, n_clusters]
    elif isinstance(n_clusters, list):
        n_cluster_try = n_clusters
    else:
        n_cluster_try = [1,x.shape[1]]
        
    n_cluster_try = list(range(n_cluster_try[0], n_cluster_try[1]+1))
    
    for n_cluster in n_cluster_try:
        
        clustering = AgglomerativeClustering(
          metric='precomputed',
          n_clusters=n_cluster,
          linkage='complete',
        ).fit(dist_hard)

        labels = clustering.labels_

        groups = [list(np.where(labels == i)[0]) for i in range(n_cluster)]
        blocks = [torch.sum(score_mat[groups[i]][:,groups[i]]) for i in range(n_cluster)]
        block_sum = torch.sum(torch.stack(blocks))
        total_sum = torch.sum(score_mat)
        residual_sum = total_sum - block_sum
        residual_ratio = residual_sum / total_sum
        
        if verbose == True:
            print(f'n_group={n_cluster}, residual_ratio={residual_ratio}')
        
        if residual_ratio < res_th:
            results['n_groups'] = n_cluster
            results['labels'] = list(labels)
            results['groups'] = groups
            
    if results['n_groups'] > 1:
        print(f'{mode} separability detected')
    else:
        print(f'{mode} separability not detected')
            
    return results


def batch_grad_normgrad(model, x, group, create_graph=False):
    # x in shape (Batch, Length)
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))

    def jac(x):
        input_grad = batch_jacobian(model, x, create_graph=True)
        input_grad_A = input_grad[:,group_A]
        norm = torch.norm(input_grad_A, dim=1, keepdim=True) + 1e-6
        input_grad_A_normalized = input_grad_A/norm
        return input_grad_A_normalized
    
    def _jac_sum(x):
        return jac(x).sum(dim=0)
    
    return torch.autograd.functional.jacobian(_jac_sum, x, create_graph=create_graph).permute(1,0,2)[:,:,group_B]


def get_dependence(model, x, group):
    group_A = group
    group_B = list(set(range(x.shape[1])) - set(group))
    grad_normgrad = batch_grad_normgrad(model, x, group=group)
    std = torch.std(x, dim=0)
    dependence = grad_normgrad * std[None,group_A,None] * std[None,None,group_B]
    dependence = torch.median(torch.abs(dependence), dim=0)[0]
    return dependence

def test_symmetry(model, x, group, dependence_th=1e-3):
    '''
        detect function separability
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            group : a list of indices
            dependence_th : float
                threshold of dependence

        Returns:
        --------
            bool
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,3))
        >>> print(test_symmetry(model, x, [1,2])) # True
        >>> print(test_symmetry(model, x, [0,2])) # False
    '''
    if len(group) == x.shape[1] or len(group) == 0:
        return True
    
    dependence = get_dependence(model, x, group)
    max_dependence = torch.max(dependence)
    return max_dependence < dependence_th


def test_separability(model, x, groups, mode='add', threshold=1e-2, bias=0):
    '''
        test function separability
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            mode : str
                mode = 'add' or mode = 'mul'
            score_th : float
                threshold of score
            res_th : float
                threshold of residue
            bias : float
                bias (for multiplicative separability)
            verbose : bool

        Returns:
        --------
            bool
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]+x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,3))
        >>> print(test_separability(model, x, [[0],[1,2]], mode='mul')) # True
        >>> print(test_separability(model, x, [[0],[1,2]], mode='add')) # False
    '''
    if mode == 'add':
        hessian = batch_hessian(model, x)
    elif mode == 'mul':
        compose = lambda *F: reduce(lambda f, g: lambda x: f(g(x)), F)
        hessian = batch_hessian(compose(torch.log, torch.abs, lambda x: x+bias, model), x)

    std = torch.std(x, dim=0)
    hessian_normalized = hessian * std[None,:] * std[:,None]
    score_mat = torch.median(torch.abs(hessian_normalized), dim=0)[0]
    
    sep_bool = True
    
    # internal test
    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i+1, n_groups):
            sep_bool *= torch.max(score_mat[groups[i]][:,groups[j]]) < threshold
            
    # external test
    group_id = [x for xs in groups for x in xs]
    nongroup_id = list(set(range(x.shape[1])) - set(group_id))
    if len(nongroup_id) > 0 and len(group_id) > 0:
        sep_bool *= torch.max(score_mat[group_id][:,nongroup_id]) < threshold 

    return sep_bool

def test_general_separability(model, x, groups, threshold=1e-2):
    '''
        test function separability
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            mode : str
                mode = 'add' or mode = 'mul'
            score_th : float
                threshold of score
            res_th : float
                threshold of residue
            bias : float
                bias (for multiplicative separability)
            verbose : bool

        Returns:
        --------
            bool
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: x[:,[0]] ** 2 * (x[:,[1]]**2+x[:,[2]]**2)**2
        >>> x = torch.normal(0,1,size=(100,3))
        >>> print(test_general_separability(model, x, [[1],[0,2]])) # False
        >>> print(test_general_separability(model, x, [[0],[1,2]])) # True
    '''
    grad = batch_jacobian(model, x)

    gensep_bool = True

    n_groups = len(groups)
    for i in range(n_groups):
        for j in range(i+1,n_groups):
            group_A = groups[i]
            group_B = groups[j]
            for member_A in group_A:
                for member_B in group_B:
                    def func(x):
                        grad = batch_jacobian(model, x, create_graph=True)
                        return grad[:,[member_B]]/grad[:,[member_A]]
                    # test if func is multiplicative separable
                    gensep_bool *= test_separability(func, x, groups, mode='mul', threshold=threshold)
    return gensep_bool


def get_molecule(model, x, sym_th=1e-3, verbose=True):
    '''
        how variables are combined hierarchically
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            sym_th : float
                threshold of symmetry
            verbose : bool

        Returns:
        --------
            list
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
        >>> x = torch.normal(0,1,size=(100,8))
        >>> get_molecule(model, x, verbose=False)
        [[[0], [1], [2], [3], [4], [5], [6], [7]],
         [[0, 1], [2, 3], [4, 5], [6, 7]],
         [[0, 1, 2, 3], [4, 5, 6, 7]],
         [[0, 1, 2, 3, 4, 5, 6, 7]]]
    '''
    n = x.shape[1]
    atoms = [[i] for i in range(n)]
    molecules = []
    moleculess = [copy.deepcopy(atoms)]
    already_full = False
    n_layer = 0
    last_n_molecule = n
    
    while True:
        

        pointer = 0
        current_molecule = []
        remove_atoms = []
        n_atom = 0

        while len(atoms) > 0:
            
            # assemble molecule
            atom = atoms[pointer]
            if verbose:
                print(current_molecule)
                print(atom)

            if len(current_molecule) == 0:
                full = False
                current_molecule += atom
                remove_atoms.append(atom)
                n_atom += 1
            else:
                # try assemble the atom to the molecule
                if len(current_molecule+atom) == x.shape[1] and already_full == False and n_atom > 1 and n_layer > 0:
                    full = True
                    already_full = True
                else:
                    full = False
                    if test_symmetry(model, x, current_molecule+atom, dependence_th=sym_th):
                        current_molecule += atom
                        remove_atoms.append(atom)
                        n_atom += 1

            pointer += 1
            
            if pointer == len(atoms) or full:
                molecules.append(current_molecule)
                if full:
                    molecules.append(atom)
                    remove_atoms.append(atom)
                # remove molecules from atoms
                for atom in remove_atoms:
                    atoms.remove(atom)
                current_molecule = []
                remove_atoms = []
                pointer = 0
                
        # if not making progress, terminate
        if len(molecules) == last_n_molecule:
            def flatten(xss):
                return [x for xs in xss for x in xs]
            moleculess.append([flatten(molecules)])
            break
        else:
            moleculess.append(copy.deepcopy(molecules))
            
        last_n_molecule = len(molecules)

        if len(molecules) == 1:
            break

        atoms = molecules
        molecules = []
        
        n_layer += 1
        
        #print(n_layer, atoms)
        
        
    # sort
    depth = len(moleculess) - 1

    for l in list(range(depth,0,-1)):

        molecules_sorted = []
        molecules_l = moleculess[l]
        molecules_lm1 = moleculess[l-1]


        for molecule_l in molecules_l:
            start = 0
            for i in range(1,len(molecule_l)+1):
                if molecule_l[start:i] in molecules_lm1:

                    molecules_sorted.append(molecule_l[start:i])
                    start = i

        moleculess[l-1] = molecules_sorted

    return moleculess


def get_tree_node(model, x, moleculess, sep_th=1e-2, skip_test=True):
    '''
        get tree nodes
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            sep_th : float
                threshold of separability
            skip_test : bool
                if True, don't test the property of each module (to save time)

        Returns:
        --------
            arities : list of numbers
            properties : list of strings
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
        >>> x = torch.normal(0,1,size=(100,8))
        >>> moleculess = get_molecule(model, x, verbose=False)
        >>> get_tree_node(model, x, moleculess, skip_test=False)
    '''
    arities = []
    properties = []
    
    depth = len(moleculess) - 1

    for l in range(depth):
        molecules_l = copy.deepcopy(moleculess[l])
        molecules_lp1 = copy.deepcopy(moleculess[l+1])
        arity_l = []
        property_l = []

        for molecule in molecules_lp1:
            start = 0
            arity = 0
            groups = []
            for i in range(1,len(molecule)+1):
                if molecule[start:i] in molecules_l:
                    groups.append(molecule[start:i])
                    start = i
                    arity += 1
            arity_l.append(arity)

            if arity == 1:
                property = 'Id'
            else:
                property = ''
                # test property
                if skip_test:
                    gensep_bool = False
                else:
                    gensep_bool = test_general_separability(model, x, groups, threshold=sep_th)
                    
                if gensep_bool:
                    property = 'GS'
                if l == depth - 1:
                    if skip_test:
                        add_bool = False
                        mul_bool = False
                    else:
                        add_bool = test_separability(model, x, groups, mode='add', threshold=sep_th)
                        mul_bool = test_separability(model, x, groups, mode='mul', threshold=sep_th)
                    if add_bool:
                        property = 'Add'
                    if mul_bool:
                        property = 'Mul'


            property_l.append(property)


        arities.append(arity_l)
        properties.append(property_l)
        
    return arities, properties


def plot_tree(model, x, in_var=None, style='tree', sym_th=1e-3, sep_th=1e-1, skip_sep_test=False, verbose=False):
    '''
        get tree graph
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            in_var : list of symbols
                input variables
            style : str
                'tree' or 'box'
            sym_th : float
                threshold of symmetry
            sep_th : float
                threshold of separability
            skip_sep_test : bool
                if True, don't test the property of each module (to save time)
            verbose : bool

        Returns:
        --------
            a tree graph
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> model = lambda x: ((x[:,[0]] ** 2 + x[:,[1]] ** 2) ** 2 + (x[:,[2]] ** 2 + x[:,[3]] ** 2) ** 2) ** 2 + ((x[:,[4]] ** 2 + x[:,[5]] ** 2) ** 2 + (x[:,[6]] ** 2 + x[:,[7]] ** 2) ** 2) ** 2
        >>> x = torch.normal(0,1,size=(100,8))
        >>> plot_tree(model, x)
    '''
    moleculess = get_molecule(model, x, sym_th=sym_th, verbose=verbose)
    arities, properties = get_tree_node(model, x, moleculess, sep_th=sep_th, skip_test=skip_sep_test)

    n = x.shape[1]
    var = None

    in_vars = []

    if in_var == None:
        for ii in range(1, n + 1):
            exec(f"x{ii} = sympy.Symbol('x_{ii}')")
            exec(f"in_vars.append(x{ii})")
    elif type(var[0]) == Symbol:
        in_vars = var
    else:
        in_vars = [sympy.symbols(var_) for var_ in var]
        

    def flatten(xss):
        return [x for xs in xss for x in xs]

    def myrectangle(center_x, center_y, width_x, width_y):
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y + width_y/2, center_y + width_y/2], color='k') # up
        plt.plot([center_x - width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y - width_y/2], color='k') # down
        plt.plot([center_x - width_x/2, center_x - width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # left
        plt.plot([center_x + width_x/2, center_x + width_x/2], [center_y - width_y/2, center_y + width_y/2], color='k') # left

    depth = len(moleculess)

    delta = 1/n
    a = 0.3
    b = 0.15
    y0 = 0.5


    # draw rectangles
    for l in range(depth-1):
        molecules = moleculess[l+1]
        n_molecule = len(molecules)

        centers = []

        acc_arity = 0

        for i in range(n_molecule):
            start_id = len(flatten(molecules[:i]))
            end_id = len(flatten(molecules[:i+1]))

            center_x = (start_id + (end_id - 1 - start_id)/2) * delta + delta/2
            center_y = (l+1/2)*y0
            width_x = (end_id - start_id - 1 + 2*a)*delta
            width_y = 2*b

            # add text (numbers) on rectangles
            if style == 'box':
                myrectangle(center_x, center_y, width_x, width_y)
                plt.text(center_x, center_y, properties[l][i], fontsize=15, horizontalalignment='center',
                verticalalignment='center')
            elif style == 'tree':
                # if 'GS', no rectangle, n=arity tilted lines
                # if 'Id', no rectangle, n=arity vertical lines
                # if 'Add' or 'Mul'. rectangle, "+" or "x"
                # if '', rectangle
                property = properties[l][i]
                if property == 'GS' or property == 'Add' or property == 'Mul':
                    color = 'blue'
                    arity = arities[l][i]
                    for j in range(arity):

                        if l == 0:
                            # x = (start_id + j) * delta + delta/2, center_x
                            # y = center_y - b, center_y + b
                            plt.plot([(start_id + j) * delta + delta/2, center_x], [center_y - b, center_y + b], color=color)
                        else:
                            # x = last_centers[acc_arity:acc_arity+arity], center_x
                            # y = center_y - b, center_y + b
                            plt.plot([last_centers[acc_arity+j], center_x], [center_y - b, center_y + b], color=color)

                    acc_arity += arity

                    if property == 'Add' or property == 'Mul':
                        if property == 'Add':
                            symbol = '+'
                        else:
                            symbol = '*'

                        plt.text(center_x, center_y + b, symbol, horizontalalignment='center',
                verticalalignment='center', color='red', fontsize=40)
                if property == 'Id':
                    plt.plot([center_x, center_x], [center_y-width_y/2, center_y+width_y/2], color='black')
                    
                if property == '':
                    myrectangle(center_x, center_y, width_x, width_y)



            # connections to the next layer
            plt.plot([center_x, center_x], [center_y+width_y/2, center_y+y0-width_y/2], color='k')
            centers.append(center_x)
        last_centers = copy.deepcopy(centers)

    # connections from input variables to the first layer
    for i in range(n):
        x_ = (i + 1/2) * delta
        # connections to the next layer
        plt.plot([x_, x_], [0, y0/2-width_y/2], color='k')
        plt.text(x_, -0.05*(depth-1), f'${latex(in_vars[moleculess[0][i][0]])}$', fontsize=20, horizontalalignment='center')
    plt.xlim(0,1)
    #plt.ylim(0,1);
    plt.axis('off');
    plt.show()


def test_symmetry_var(model, x, input_vars, symmetry_var):
    '''
        test symmetry
        
        Args:
        -----
            model : MultKAN, MLP or python function
            x : 2D torch.float
                inputs
            input_vars : list of sympy symbols
            symmetry_var : sympy expression

        Returns:
        --------
            cosine similarity
            
        Example
        -------
        >>> from kan.hypothesis import *
        >>> from sympy import *
        >>> model = lambda x: x[:,[0]] * (x[:,[1]] + x[:,[2]])
        >>> x = torch.normal(0,1,size=(100,8))
        >>> input_vars = a, b, c = symbols('a b c')
        >>> symmetry_var = b + c
        >>> test_symmetry_var(model, x, input_vars, symmetry_var);
        >>> symmetry_var = b * c
        >>> test_symmetry_var(model, x, input_vars, symmetry_var);
    '''
    orig_vars = input_vars
    sym_var = symmetry_var
    
    # gradients wrt to input (model)
    input_grad = batch_jacobian(model, x)

    # gradients wrt to input (symmetry var)
    func = lambdify(orig_vars, sym_var,'numpy') # returns a numpy-ready function

    func2 = lambda x: func(*[x[:,[i]] for i in range(len(orig_vars))])
    sym_grad = batch_jacobian(func2, x)
    
    # get id
    idx = []
    sym_symbols = list(sym_var.free_symbols)
    for sym_symbol in sym_symbols:
        for j in range(len(orig_vars)):
            if sym_symbol == orig_vars[j]:
                idx.append(j)
                
    input_grad_part = input_grad[:,idx]
    sym_grad_part = sym_grad[:,idx]
    
    cossim = torch.abs(torch.sum(input_grad_part * sym_grad_part, dim=1)/(torch.norm(input_grad_part, dim=1)*torch.norm(sym_grad_part, dim=1)))
    
    ratio = torch.sum(cossim > 0.9)/len(cossim)
    
    print(f'{100*ratio}% data have more than 0.9 cosine similarity')
    if ratio > 0.9:
        print('suggesting symmetry')
    else:
        print('not suggesting symmetry')
        
    return cossim

IV. Summary and Reflections

        By combining a mathematical theorem with deep learning, KANs offer a new approach to scientific computing and interpretable AI. Breakthroughs are still needed for high-dimensional applications, but their potential for modeling complex low-dimensional functions deserves attention. With improved computational efficiency and an extended theoretical foundation, they may become an important complement to MLPs.

        1. KAN Architecture

  • Key designs. Learnable activation functions: the "weight" on each connection is replaced by a learnable univariate function (e.g., a spline or polynomial) rather than a fixed activation such as ReLU. Layered structure: the input and hidden layers, and the hidden and output layers, are connected through univariate functions, which are stacked across layers. Parameter efficiency: thanks to the theoretical guarantee, a KAN may match or exceed an MLP's approximation quality with fewer parameters.

  • Example structure. Input layer → hidden layer: each input node connects to the hidden nodes through univariate functions \phi_{q,i} \left( x_{i} \right). Hidden layer → output layer: the hidden nodes are combined through another set of univariate functions \psi_{q} to produce the output.

        2. Strengths and Characteristics

  • High approximation efficiency: backed by the representation theorem, KANs can in principle approximate complex functions with fewer parameters, and they perform well on low-dimensional scientific computing tasks (such as solving differential equations).

  • Interpretability: the univariate functions can be visualized, making it easy to analyze how inputs relate to outputs; the network structure directly mirrors the functional decomposition, so the logic is transparent.

  • Flexible function learning: the activation functions adapt during training (learning smooth or non-smooth functions as needed), and symbolic formulas can be extracted (for example, recovering physical laws from data).

        3. Challenges and Limitations

  • Computational cost: learning the univariate functions (e.g., their spline parameterizations) can increase training time and memory usage, and optimizing high-order smooth functions places greater demands on hardware and algorithms.

  • Generalization: performance on high-dimensional data (such as images and text) has not been thoroughly validated and may lag behind traditional MLPs.

  • Training difficulty: new optimization strategies are needed to avoid overfitting or underfitting the univariate functions.

        4. Application Scenarios

  • Scientific computing: solving differential equations, physical modeling, chemical simulation, and other tasks that require high-precision function approximation.

  • Interpretability-critical domains: medical diagnosis, financial risk control, and other scenarios that require explicit input-output relationships.

  • Symbolic regression: automatically discovering mathematical formulas (such as physical laws) from data.

        5. Comparison with Traditional MLPs

        6. Research Progress

  • Recent work: in 2024, a team from MIT and collaborators proposed the KAN architecture (the paper "KAN: Kolmogorov-Arnold Networks"), validating its efficiency and interpretability on low-dimensional tasks.

  • Open-source implementations: preliminary implementations already exist in PyTorch and other frameworks.


【Author's Statement】

        The content and viewpoints shared in this article come from the original paper "KAN: Kolmogorov-Arnold Networks" and are intended to introduce and discuss the innovations and practical value of that research. The author respects and follows academic norms and strives for accuracy and objectivity. For any questions or further information, please refer to the original paper or contact its authors.


 【Follow Us】

        If you are interested in neural networks, swarm intelligence algorithms, and artificial intelligence, follow 【灵犀拾荒者】 for more cutting-edge articles, hands-on case studies, and technical insights!


With the power of technology and a curious mind, building an intelligent world with warmth.
