jiyuhang 3 months ago
commit
6b12b590ad

+ 9 - 0
.gitignore

@@ -0,0 +1,9 @@
+.idea/
+__pycache__/
+*.pyc
+*.xlsx
+*.txt
+*.pkl
+*.csv
+*.png
+*.json

+ 407 - 0
Analysis/pearsonr.py

@@ -0,0 +1,407 @@
+import csv
+import sys
+
+sys.path.append("..")
+import os
+import pandas as pd
+import config
+from Database.database_ import Database, DatabaseParam
+import json
+from scipy import stats
+import numpy as np
+import pickle
+from utils.tools import cal_vari_without_zero_nan, cal_vari_without_nan, df_is_symetry, quick_sort, load_transfer_file_name_code
+
+class DFMat:
+    """Takes a set of field keys and fetches all of their data from the database. The core attribute is a pandas.DataFrame that merges every field's series; after fetching, the data is cleaned and preprocessed."""
+    def __init__(self, keys_file_dir: str, min_records: int, db_param: DatabaseParam, transfer_file_dir: str, is_from_local: bool = True):
+        self.bad_keys = config.EXCLUDE_WORDS
+        self.keys_file_dir = keys_file_dir
+        self.min_records = min_records
+        self.keys = self.load_keys()  # sorted in ascending order
+        self.db_param = db_param
+        self.transfer_file_dir = transfer_file_dir
+        self.name_2code_dict, self.code_2name_dict = self.load_transfer_file()  # name<->code lookup dicts
+        self.diff_words = config.DIFF_WORDS  # fields whose names contain these words are differenced to make them stationary
+        self.is_from_local = is_from_local
+        # Cache the database data locally to avoid repeated queries
+        self.local_df_merge_path = config.DF_MERGE_FILE_PATH
+        self.df_merge = self.__construct()  # build the merged data at initialization
+
+    def load_keys(self):
+        keys_list = []
+        with open(self.keys_file_dir, "r", encoding="utf-8") as f:
+            csv_reader = csv.reader(f)
+            try:
+                next(csv_reader)  # skip the header row
+            except StopIteration:
+                print('File is empty:', self.keys_file_dir)
+            for row in csv_reader:
+                records_num = int(row[6])
+                records_name = row[0]
+                if records_num < self.min_records: continue
+                keys_list.append(records_name)
+        # Sort in ascending order
+        keys_list = sorted(keys_list)
+        # Drop fields that are on the exclude list
+        keys_list = self.exclude_keys(keys_list)
+        return keys_list  # ascending order
+
+    def exclude_keys(self, keys_list: list):
+        """Drop any key whose name contains a word from the exclude list"""
+        new_keys = []
+        for name in keys_list:
+            flag = False
+            for bad_key in self.bad_keys:
+                if bad_key in name:
+                    flag = True
+                    break
+            if flag: continue
+            new_keys.append(name)
+        return new_keys
+
+    def load_transfer_file(self):
+        """Load the name<->code transfer file"""
+        path = self.transfer_file_dir
+        return load_transfer_file_name_code(path)
+        # if not os.path.exists(self.transfer_file_dir):
+        #     raise FileNotFoundError('File not found:', self.transfer_file_dir)
+        # with open(self.transfer_file_dir, "r", encoding="utf-8") as f:
+        #     json_data = json.load(f)
+        # return json_data.get('name_2_code'), json_data.get('code_2_name')
+
+    def save_df_merge(self, data: pd.DataFrame):
+        """Save the merged DataFrame to a local file"""
+        with open(self.local_df_merge_path, 'wb') as f:
+            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+        print(f'mat_shape:{data.shape}, file saved to:', self.local_df_merge_path)
+
+    def load_from_local(self) -> pd.DataFrame:
+        """Load the merged DataFrame from the local cache"""
+        with open(self.local_df_merge_path, 'rb') as f:
+            local_data = pickle.load(f)
+        return local_data
+
+    def normalize(self, data: pd.DataFrame) -> pd.DataFrame:
+        """Normalize the data to remove the effect of units"""
+        # Not needed for the Pearson coefficient, so this is a no-op placeholder
+        pass
+
+    @staticmethod
+    def diff_tool(data: pd.Series):
+        """Difference a cumulative, monotonically increasing series"""
+        data = data.copy()
+        # Replace inf and 0 values with NaN
+        data.replace([np.inf, -np.inf, 0], np.nan, inplace=True)
+        data = data.diff()
+        # Negative differences (e.g. counter resets) are replaced with NaN
+        data[data < 0] = np.nan
+        data.iloc[0] = data.mean()  # the first diff is always NaN; seed it with the mean (iloc, so it works for any index)
+        # Forward-fill the remaining NaNs
+        data.ffill(inplace=True)
+        return data
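+    # Worked example of diff_tool on a cumulative counter (illustrative values, not project data):
+    #   raw:   [0, 10, 25, 25, 20, 40]   -> zeros become NaN -> [NaN, 10, 25, 25, 20, 40]
+    #   diff:  [NaN, NaN, 15, 0, -5, 20] -> negatives become NaN, first element set to the mean (~11.67)
+    #   ffill: [11.67, 11.67, 15, 0, 0, 20]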
+
+    def stabilize(self, data: pd.DataFrame) -> pd.DataFrame:
+        """Difference the data to make it stationary"""
+        if len(self.diff_words) == 0: return data
+
+        # Collect all column labels, keeping only the fields that need differencing
+        col_label_list = data.columns.tolist()
+        # Drop the timestamp column
+        if 'time' in col_label_list:
+            col_label_list.remove('time')
+        # Find the columns that need to be made stationary
+        diff_label_list = set()
+        for col in col_label_list:
+            name = self.code_2name_dict[col]
+            for dword in self.diff_words:
+                if dword in name:
+                    diff_label_list.add(col)
+        diff_label_list = list(diff_label_list)
+
+        for col in diff_label_list:
+            data.loc[:, col] = self.diff_tool(data.loc[:, col])
+
+        return data
+
+    @staticmethod
+    def remove_outliers(data: pd.Series, fill_value=0, times: int = 1) -> pd.Series:
+        """Remove outliers from the series; the fill_value argument is overwritten by the column mean on each pass"""
+        data = data.copy(deep=True)
+        for _ in range(abs(times)):  # repeat `times` passes
+            # Compute mean and standard deviation, ignoring NaN
+            mean, std_dev = cal_vari_without_nan(data)
+            fill_value = mean
+            threshold = 3 * std_dev
+            limit_top = mean + threshold
+            limit_low = mean - threshold
+            # Flag outliers beyond 3 sigma; zeros and NaN are left untouched
+            mask = data.notna() & (data != 0) & ((data < limit_low) | (data > limit_top))
+            # Fill the outliers
+            data.loc[mask] = fill_value
+        return data
+
+    def clean(self, data: pd.DataFrame) -> pd.DataFrame:
+        """Clean the data: replace outliers and NaN with the column mean; zeros are excluded from the process"""
+        # Collect all column labels
+        col_label_list = data.columns.tolist()
+        if 'time' in col_label_list: col_label_list.remove('time')  # leave the time column alone
+        # Remove outliers column by column
+        for col_label in col_label_list:
+            col_series = data.loc[:, col_label]
+            data.loc[:, col_label] = self.remove_outliers(col_series, times=1)
+        # Then fill the remaining NaNs with the column means
+        cols_mean = data[col_label_list].mean()  # mean() skips NaN automatically
+        cols_mean = cols_mean.fillna(0)
+        data[col_label_list] = data[col_label_list].fillna(cols_mean)
+        return data
+
+    def fetch(self) -> pd.DataFrame:
+        """Fetch the raw data from the database; avoid cleaning data inside this method"""
+        # Database operations stay inside the context manager
+        data_names = self.keys
+        data_codes = [self.name_2code_dict.get(name) for name in data_names]
+        # Pull the data from the database
+        with Database(self.db_param) as db:  # connect
+            # Check that the table exists
+            if not db.sheet_exists(config.DB_SHEET_NAME):
+                raise RuntimeError(f'Table {config.DB_SHEET_NAME} does not exist in database {config.DB_NAME}!')
+            # Run the SQL query
+            group_df = db.query_sql_time_series_group2data_frame(
+                code_name_dict=self.code_2name_dict,
+                project_id=config.PROJECT_ID,
+                sheet_name=config.DB_SHEET_NAME,
+                data_codes=data_codes,
+                start_year=config.CHECK_YEAR_START,
+                end_year=config.CHECK_YEAR_END,
+                start_month=config.CHECK_MONTH_START,
+                end_month=config.CHECK_MONTH_END,
+                start_day=config.CHECK_DAY_START,
+                end_day=config.CHECK_DAY_END,
+                start_hour=config.CHECK_HOUR_START,
+                end_hour=config.CHECK_HOUR_END,
+                start_minute=config.CHECK_MINUTE_START,
+                end_minute=config.CHECK_MINUTE_END,
+                start_second=config.CHECK_SECONDS_START,
+                end_second=config.CHECK_SECONDS_END)
+        return group_df
+
+    def __construct(self):
+        """Build the merged DataFrame of all qualifying fields"""
+        # First try to load the data from the local cache
+        if self.is_from_local:
+            if os.path.exists(self.local_df_merge_path):
+                print(f'Loading database data from local cache {self.local_df_merge_path}')
+                return self.load_from_local()
+            else:
+                print(f'Loading from local cache {self.local_df_merge_path} failed, file does not exist!')
+        # Otherwise fetch from the database
+        print('Trying to fetch data from the database!')
+        group_df = self.fetch()
+        # Clean the data: remove NaN and outliers
+        group_df = self.clean(group_df)
+        # Make the data stationary (no NaN remains at this point)
+        group_df = self.stabilize(group_df)
+        # Save if no cached file exists yet
+        if not os.path.exists(self.local_df_merge_path):
+            self.save_df_merge(group_df)
+        return group_df
+
+    def get_df_merge(self):
+        return self.df_merge
+
+
+class PearsonrMat(DFMat):
+    """实现皮尔逊相关系数矩阵,核心属性为pandas.Dataframe,要求键入key,核心的df行和列也是按照给定的keys写入"""
+    def __init__(self, keys_file_dir: str, min_records:int, db_param: DatabaseParam, transfer_file_dir:str, is_from_local:bool=True):
+        super().__init__(keys_file_dir=keys_file_dir, min_records=min_records, db_param=db_param, transfer_file_dir=transfer_file_dir, is_from_local=is_from_local)
+        self.r_mat = None
+        self.lag_mat = None
+
+    def r_mat_filter(self):
+        """Drop fields from the matrix that only correlate with themselves"""
+        # Find the labels to delete first
+        filter_label_list = []
+        label_list = self.r_mat.columns.tolist()
+        for label in label_list:
+            r_col = self.r_mat.loc[:, label]
+            non_zero_counter = 0
+            for value in r_col:
+                if abs(value) > config.PEARSONR_VALUE_THRESHOLD:  # use |r| so negative correlations also count
+                    non_zero_counter += 1
+            if non_zero_counter < 2:  # only the diagonal entry is above threshold
+                filter_label_list.append(label)
+        self.r_mat.drop(filter_label_list, axis=0, inplace=True)
+        self.r_mat.drop(filter_label_list, axis=1, inplace=True)
+
+    def pearsonr_with_lag(self, a_series_data_label: str, b_series_data_label: str):
+        """Pearson correlation scanned over a range of lags"""
+
+        lags = config.MAX_LAG
+
+        if lags == 0:
+            left_point = 0
+            right_point = 1
+        elif lags > 0:
+            left_point = -lags
+            right_point = lags + 1  # include +lags in the scan
+        else:
+            raise ValueError('The maximum lag must not be negative', lags)
+
+        step = config.STEP
+        # Correlation coefficients at each lag
+        list_r_lag = []
+        for lag in range(left_point, right_point, step):
+            if lag < 0:  # a lags behind b
+                series_a_shifted = self.df_merge.loc[:, a_series_data_label][-lag:]
+                series_b_shifted = self.df_merge.loc[:, b_series_data_label][:lag]
+            elif lag > 0:  # b lags behind a
+                series_a_shifted = self.df_merge.loc[:, a_series_data_label][:-lag]
+                series_b_shifted = self.df_merge.loc[:, b_series_data_label][lag:]
+            else:  # zero lag
+                series_a_shifted = self.df_merge.loc[:, a_series_data_label]
+                series_b_shifted = self.df_merge.loc[:, b_series_data_label]
+            # Compute the Pearson coefficient and its significance
+            r, p_value = stats.pearsonr(series_a_shifted, series_b_shifted)
+            # Keep only significant results
+            if p_value <= config.P_VALUE_THRESHOLD:
+                list_r_lag.append(np.float32(r))
+        if len(list_r_lag) > 0:
+            return max(list_r_lag)  # the largest significant correlation across lags
+        else:
+            return 0
+
+    def pearsonr_(self, a_series_data_label: str, b_series_data_label: str)->float:
+        a_series_data = self.df_merge.loc[:, a_series_data_label]
+        b_series_data = self.df_merge.loc[:, b_series_data_label]
+        r, p_value = stats.pearsonr(a_series_data, b_series_data)
+        if p_value <= config.P_VALUE_THRESHOLD:  # result is significant
+            return np.float32(r)
+        else:
+            return np.float32(0)
+
+    def skip_tool(self, series_a_name: str, series_b_name: str) -> bool:
+        """Return True for field-name pairs whose correlation is known to be uninformative.
+        The keywords below match the Chinese field names from the transfer file."""
+        # Convert codes to display names
+        series_a_name = self.code_2name_dict.get(series_a_name)
+        series_b_name = self.code_2name_dict.get(series_b_name)
+        skip_pairs = [
+            ('温度', '温度'), ('次数', '次数'), ('累计', '累计'),
+            ('电流', '电流'), ('电压', '电压'),
+            ('电流', '温度'), ('温度', '电流'),
+            ('累计电量', '累计电量'),
+            ('运行时间', '累计电量'), ('累计电量', '运行时间'),
+            ('运行时间', '运行时间'), ('时间设定', '时间设定'),
+        ]
+        for a_word, b_word in skip_pairs:
+            if a_word in series_a_name and b_word in series_b_name:
+                return True
+        return False
+
+    def calculate_pearsonr_mat(self):
+        """Compute the Pearson correlation matrix"""
+        # If a local result exists, load it instead of recomputing
+        if os.path.exists(config.R_MAT_JSON_PATH):
+            print(f"Loading Pearson matrix from local file, {config.R_MAT_JSON_PATH}")
+            with open(config.R_MAT_JSON_PATH, 'rb') as f:
+                self.r_mat = pickle.load(f)
+            return
+
+        # Compute first, then label rows/columns with the Chinese display names
+        all_labels_code = [k for k in self.df_merge.columns.tolist() if k != 'time']
+        all_labels_name = sorted([self.code_2name_dict.get(l) for l in all_labels_code])  # ascending
+        self.r_mat = pd.DataFrame(index=all_labels_name, columns=all_labels_name, dtype=np.float32)
+        self.r_mat.fillna(0, inplace=True)  # initialize everything to 0
+        for a_label_idx in range(0, len(all_labels_code), 1):  # row labels
+            for b_label_idx in range(a_label_idx, len(all_labels_code), 1):  # column labels
+                # Skip pairs that are known to be uninformative
+                a_label = all_labels_code[a_label_idx]
+                b_label = all_labels_code[b_label_idx]
+                if self.skip_tool(a_label, b_label):
+                    print(f'Skipping pair: {a_label},{b_label}')
+                    self.r_mat.loc[self.code_2name_dict.get(a_label), self.code_2name_dict.get(b_label)] = np.float32(0)
+                    self.r_mat.loc[self.code_2name_dict.get(b_label), self.code_2name_dict.get(a_label)] = np.float32(0)
+                    continue  # without this, the skipped pair would be computed anyway
+                # Do the actual computation
+                if config.IS_LAG:
+                    result = self.pearsonr_with_lag(a_label, b_label)
+                else:
+                    result = self.pearsonr_(a_label, b_label)
+                # Keep the matrix symmetric
+                self.r_mat.loc[self.code_2name_dict.get(a_label), self.code_2name_dict.get(b_label)] = result
+                self.r_mat.loc[self.code_2name_dict.get(b_label), self.code_2name_dict.get(a_label)] = result
+        # Filter out uncorrelated fields
+        self.r_mat_filter()
+        # Save the result
+        self.save_pearsonr_mat()
+
+    def save_pearsonr_mat(self):
+        path = config.R_MAT_JSON_PATH
+        if os.path.exists(path):
+            os.remove(path)
+        with open(path, 'wb') as f:
+            pickle.dump(self.r_mat, f, protocol=pickle.HIGHEST_PROTOCOL)
+        print(f'mat_shape:{self.r_mat.shape}, file saved to:', path)
+
+
+    def query_r_rank_n(self, target: str, n: int = -1) -> list[str]:
+        """Given a target field, pick the top-n fields from the Pearson matrix; n = -1 returns all non-zero entries"""
+        if self.r_mat is None:
+            raise Exception('r_mat is None; compute the Pearson matrix first!')
+        # The Pearson matrix is symmetric, so a single row or column is enough
+        if not df_is_symetry(self.r_mat):
+            raise RuntimeError('The Pearson matrix is not symmetric; check the computation!')
+        # Prepare for sorting
+        label_list = self.r_mat.index.tolist()
+        if target not in label_list:
+            raise ValueError('Query field does not exist', target)
+        # Validate the input
+        if n == -1:
+            n = np.sum(np.abs(self.r_mat.loc[:, target].to_numpy()) > 0)
+            n = int(n)
+        elif n <= 0:
+            raise RuntimeError('Invalid n, must be greater than 0', n)
+
+        elements = []
+        for row_label in label_list:
+            elements.append((row_label, self.r_mat.loc[row_label, target]))
+        # Sort ascending by the absolute value of the correlation
+        quick_sort(elements, 0, len(elements) - 1)
+        # Reverse to get descending order
+        elements = elements[::-1]
+
+        elements = [elements[e][0] for e in range(n)]
+        return elements
+
+if __name__ == '__main__':
+    # Database parameters
+    db_param = DatabaseParam(
+        db_host=config.DB_HOST,
+        db_user=config.DB_USER,
+        db_password=config.DB_PASSWORD,
+        db_name=config.DB_NAME,
+        db_port=config.DB_PORT)
+    # Fetch all the data first
+    df_mat = PearsonrMat(keys_file_dir=os.path.join(config.STATISTICS_FILE_DIR, config.STATISTICS_FILE_NAME),
+                   min_records=config.MIN_RECORDS, db_param=db_param,
+                   transfer_file_dir=os.path.join(config.ALL_ITEMS_FILE_DIR, config.TRANSFER_JSON_NAME)
+                         )
+    # Compute Pearson coefficients and p-values (with lag)
+    df_mat.calculate_pearsonr_mat()
+
+    # Test call
+    # df_mat.query_r_rank_n('反渗透总产水电导')
+
+
+
+
+
+
+
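The lag-alignment logic in pearsonr_with_lag above is easiest to sanity-check in isolation. A minimal self-contained sketch of the same slicing on synthetic series (the max_lagged_r helper and the test data are illustrative, not part of the commit):

import numpy as np
from scipy import stats

def max_lagged_r(a, b, max_lag=3, p_threshold=0.05):
    """Scan lags in [-max_lag, max_lag]; keep the largest significant r."""
    rs = []
    for lag in range(-max_lag, max_lag + 1):
        if lag < 0:          # a lags behind b
            aa, bb = a[-lag:], b[:lag]
        elif lag > 0:        # b lags behind a
            aa, bb = a[:-lag], b[lag:]
        else:                # zero lag
            aa, bb = a, b
        r, p = stats.pearsonr(aa, bb)
        if p <= p_threshold:
            rs.append(r)
    return max(rs) if rs else 0.0

rng = np.random.default_rng(0)
x = rng.normal(size=200)
y = np.roll(x, 2) + 0.1 * rng.normal(size=200)  # y lags x by 2 steps
print(max_lagged_r(x, y))  # close to 1, found at lag = +2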

+ 244 - 0
Analysis/regression.py

@@ -0,0 +1,244 @@
+import sys
+sys.path.append("..")
+from Analysis.pearsonr import DFMat, PearsonrMat
+from Database.database_ import DatabaseParam
+import config
+import os
+import json
+import pandas as pd
+from sklearn.preprocessing import StandardScaler
+from sklearn.linear_model import Lasso, LassoCV, LinearRegression
+from sklearn.model_selection import TimeSeriesSplit
+import numpy as np
+import matplotlib.pyplot as plt
+from sklearn.metrics import r2_score
+import scipy.stats as stats
+from utils.tools import set_chinese_font
+from sklearn.metrics import mean_squared_error, mean_absolute_error
+from sklearn.model_selection import cross_val_score
+from statsmodels.stats.outliers_influence import OLSInfluence
+import statsmodels.api as sm
+
+class RegressionBox(PearsonrMat):
+    """Lasso regression for feature selection followed by an OLS linear regression"""
+    def __init__(self, keys_file_dir: str, min_records: int, db_param: DatabaseParam, transfer_file_dir: str, is_from_local: bool = True):
+        super().__init__(keys_file_dir=keys_file_dir, min_records=min_records, db_param=db_param, transfer_file_dir=transfer_file_dir, is_from_local=is_from_local)
+        self.lasso_info = {"help": "x: feature names; y: target name; alpha: best alpha; coef: feature weights; intercept: intercept; n_iter: iterations; dual_gap: dual gap; tol: tolerance"}
+        self.ols_info = {"help": "x: feature names; y: target name; coef: feature weights; intercept: intercept; score: R2 coefficient of determination"}
+        self.ols_model = None  # the final OLS linear regression model
+
+    def read_features_file(self):
+        """Load the feature file that names the target Y and the candidate features X"""
+        path = config.LASSO_FEATURE_FILE_PATH
+        if not os.path.exists(path):
+            raise FileNotFoundError('File not found:', path)
+        with open(path, "r", encoding="utf-8") as f:
+            json_data = json.load(f)
+        return json_data.get('targets'), json_data.get('features')
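+    # A lasso_features_choose.json of the expected shape (the feature names below are
+    # illustrative placeholders; the target is the one used in __main__):
+    # {
+    #     "targets": ["RO1脱盐率"],
+    #     "features": ["进水流量", "进水压力"]
+    # }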
+
+    def load_features(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
+        y_label_name, x_label_name = self.read_features_file()
+        # Convert names to codes
+        y_label_code = [self.name_2code_dict.get(i) for i in y_label_name if self.name_2code_dict.get(i) in self.df_merge.columns.tolist()]
+        x_label_code = [self.name_2code_dict.get(i) for i in x_label_name if self.name_2code_dict.get(i) in self.df_merge.columns.tolist()]
+        if len(y_label_code) == 0 or len(x_label_code) == 0:
+            raise ValueError('No features to fit; check that the modeling fields exist', (y_label_code, x_label_code))
+        targets = self.df_merge.loc[:, y_label_code].copy()
+        features = self.df_merge.loc[:, x_label_code].copy()
+        time = self.df_merge.loc[:, ['time']].copy()
+        return targets, features, time
+
+    def select_features(self):
+        pass
+
+    def lasso_(self, y_value: np.ndarray, scaler_x_value: np.ndarray, n_splits: int = 5, max_iter: int = 10000):
+        """Run a Lasso regression to select features"""
+        tscv = TimeSeriesSplit(n_splits=n_splits)
+        # Search for the best alpha (LassoCV takes the grid size as n_alphas; its `alphas` parameter expects an array)
+        lasso_model = LassoCV(n_alphas=100,
+                              cv=tscv,
+                              max_iter=max_iter,
+                              random_state=42,
+                              n_jobs=-1)
+        lasso_model.fit(scaler_x_value, y_value)
+        # Record the best alpha
+        self.lasso_info['alpha'] = lasso_model.alpha_
+        # Record the intercept
+        self.lasso_info['intercept'] = lasso_model.intercept_
+        # Record the number of iterations
+        self.lasso_info['n_iter'] = lasso_model.n_iter_
+        # Record the dual gap
+        self.lasso_info['dual_gap'] = lasso_model.dual_gap_
+        # Record the tolerance
+        self.lasso_info['tol'] = lasso_model.tol
+        # Record the weights
+        self.lasso_info['coef'] = lasso_model.coef_
+
+    def ols_(self, y_value: np.ndarray, scaler_x_value: np.ndarray) -> LinearRegression:
+        """Fit an OLS regression"""
+        model = LinearRegression()
+        model.fit(scaler_x_value, y_value)
+        # Record the intercept
+        self.ols_info['intercept'] = model.intercept_
+        # Record the weights
+        self.ols_info['coef'] = model.coef_
+        # Record R²
+        self.ols_info['score'] = model.score(scaler_x_value, y_value)
+        return model
+
+    def any_regression_full(self, target_name: str):
+        """Regress the given target on all fields"""
+        pass
+
+    def any_regression_r_rank(self, target_name: str):
+        """Regress the given target on fields selected by Pearson rank"""
+        # Fields to model
+        y_label_name = target_name
+        x_label_name = self.query_r_rank_n(y_label_name)  # pick correlated fields via the Pearson ranking
+        # Drop the target itself
+        if y_label_name in x_label_name:
+            x_label_name.remove(y_label_name)
+
+        # Fetch the data
+        y_label_code = self.name_2code_dict[y_label_name]
+        x_label_code = [self.name_2code_dict.get(name) for name in x_label_name]
+
+        y = self.df_merge.loc[:, y_label_code].copy()  # target values
+        y = y.to_numpy()
+        x = self.df_merge.loc[:, x_label_code].copy()  # predictors
+        t = self.df_merge.loc[:, 'time'].copy()   # time series
+
+        # Standardize
+        scaler = StandardScaler()
+        x = scaler.fit_transform(x)
+
+        # Lasso regression for feature selection
+        self.lasso_(y_value=y, scaler_x_value=x)
+        self.lasso_info['x'] = x_label_name
+        self.lasso_info['y'] = y_label_name
+
+        # Lasso diagnostics
+        print('\n=========== Lasso training results ==================')
+        print(f"Best lambda: {self.lasso_info.get('alpha')}")
+        print(f"Y: {self.lasso_info.get('y')}")
+        print("Lasso coefficients:")
+        for feat, coef in zip(x_label_name, self.lasso_info.get('coef')):
+            print(f"  {feat}: {coef}")
+        print(f"Intercept: {self.lasso_info.get('intercept')}")
+        print(f"Iterations: {self.lasso_info.get('n_iter')}")
+        print(f"Dual gap: {self.lasso_info.get('dual_gap')}")
+        print(f"Dual gap tolerance: {self.lasso_info.get('tol')}")
+
+        # OLS regression on the features whose Lasso coefficients are non-zero
+        mask = self.lasso_info.get('coef') != 0
+        x_label_name = list(np.array(x_label_name)[mask])
+        x_label_code = list(np.array(x_label_code)[mask])
+
+        x = self.df_merge.loc[:, x_label_code]  # note: not normalized/standardized here
+        self.ols_model = self.ols_(y_value=y, scaler_x_value=x)
+        self.ols_info['x'] = x_label_name
+        self.ols_info['y'] = y_label_name
+
+        # OLS diagnostics
+        print('\n=========== OLS training results ==================')
+        print(f"OLS intercept: {self.ols_info.get('intercept')}")
+        print("OLS coefficients:")
+        for feat, coef in zip(x_label_name, self.ols_info.get('coef')):
+            print(f"  {feat}: {coef}")
+        print(f"OLS R² (training set): {self.ols_info.get('score'):.4f}")
+
+        # Basic evaluation metrics
+        y_pred = self.ols_model.predict(x)
+        residuals = y - y_pred
+        mse = mean_squared_error(y, y_pred)
+        rmse = np.sqrt(mse)
+        mae = mean_absolute_error(y, y_pred)
+        r2 = r2_score(y, y_pred)
+        # Adjusted R²
+        n = len(y)
+        p = x.shape[1]
+        adj_r2 = 1 - (1 - r2) * (n - 1) / (n - p - 1)
+        print("\n=========== Model performance metrics ==================")
+        print(f"MSE: {mse:.4f}")
+        print(f"RMSE: {rmse:.4f}")
+        print(f"MAE: {mae:.4f}")
+        print(f"R²: {r2:.4f}")
+        print(f"Adjusted R²: {adj_r2:.4f}")
+
+        # Diagnostic plots
+        set_chinese_font()
+        fig, axes = plt.subplots(2, 3, figsize=(15, 10))
+
+        # 1. Residuals vs fitted values (checks homoscedasticity and linearity)
+        axes[0, 0].scatter(y_pred, residuals, alpha=0.6)
+        axes[0, 0].axhline(y=0, color='red', linestyle='--')
+        axes[0, 0].set_xlabel('Fitted values')
+        axes[0, 0].set_ylabel('Residuals')
+        axes[0, 0].set_title('Residuals vs fitted')
+
+        # 2. Normal Q-Q plot (checks residual normality)
+        stats.probplot(residuals, dist="norm", plot=axes[0, 1])
+        axes[0, 1].set_title('Q-Q plot (normality check)')
+
+        # 3. Residual histogram
+        axes[0, 2].hist(residuals, bins=30, density=True, alpha=0.7)
+        axes[0, 2].set_xlabel('Residuals')
+        axes[0, 2].set_ylabel('Density')
+        axes[0, 2].set_title('Residual distribution')
+
+        # 4. Observed vs fitted values
+        axes[1, 0].scatter(y, y_pred, alpha=0.6)
+        min_val = min(y.min(), y_pred.min())
+        max_val = max(y.max(), y_pred.max())
+        axes[1, 0].plot([min_val, max_val], [min_val, max_val], 'red', linestyle='--')
+        axes[1, 0].set_xlabel('Observed')
+        axes[1, 0].set_ylabel('Predicted')
+        axes[1, 0].set_title('Observed vs predicted')
+        r2 = r2_score(y, y_pred)
+        axes[1, 0].text(0.05, 0.95, f'R² = {r2:.3f}', transform=axes[1, 0].transAxes)
+
+        # 5. Residual time-series plot (for time-series data)
+        axes[1, 1].plot(residuals)
+        axes[1, 1].axhline(y=0, color='red', linestyle='--')
+        axes[1, 1].set_xlabel('Time / observation index')
+        axes[1, 1].set_ylabel('Residuals')
+        axes[1, 1].set_title('Residuals over time')
+
+        # 6. Scale-location plot (checks homoscedasticity)
+        standardized_residuals = residuals / np.std(residuals)
+        axes[1, 2].scatter(y_pred, np.sqrt(np.abs(standardized_residuals)), alpha=0.6)
+        axes[1, 2].set_xlabel('Fitted values')
+        axes[1, 2].set_ylabel('√|standardized residuals|')
+        axes[1, 2].set_title('Scale-location plot')
+
+        plt.tight_layout()
+        plt.show()
+
+
+
+    def any_regression_custom(self, target_name: str, path: str):
+        """Regress on a custom field set read from a file"""
+
+    def auto_fit(self, x_label_name: str, y_label_name: str, is_use_lasso: bool = True):
+        """Regression analysis"""
+
+
+if __name__ == '__main__':
+    # Database parameters
+    db_param = DatabaseParam(
+        db_host=config.DB_HOST,
+        db_user=config.DB_USER,
+        db_password=config.DB_PASSWORD,
+        db_name=config.DB_NAME,
+        db_port=config.DB_PORT)
+
+    my_box = RegressionBox(
+        keys_file_dir=os.path.join(config.STATISTICS_FILE_DIR, config.STATISTICS_FILE_NAME),
+        min_records=config.MIN_RECORDS, db_param=db_param,
+        transfer_file_dir=os.path.join(config.ALL_ITEMS_FILE_DIR, config.TRANSFER_JSON_NAME))
+    # Compute the Pearson matrix
+    my_box.calculate_pearsonr_mat()
+    # Run the regression analysis
+    my_box.any_regression_r_rank("RO1脱盐率")
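
The two-stage fit in any_regression_r_rank (LassoCV to select features, then an unpenalized OLS on the survivors) can be demonstrated end to end on synthetic data; a sketch under those assumptions (all data here is generated, not from the database):

import numpy as np
from sklearn.linear_model import LassoCV, LinearRegression
from sklearn.model_selection import TimeSeriesSplit
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(42)
X = rng.normal(size=(300, 6))
y = 2.0 * X[:, 0] - 1.5 * X[:, 2] + 0.1 * rng.normal(size=300)  # only features 0 and 2 matter

Xs = StandardScaler().fit_transform(X)
lasso = LassoCV(n_alphas=100, cv=TimeSeriesSplit(n_splits=5), max_iter=10000).fit(Xs, y)
selected = lasso.coef_ != 0                      # stage 1: Lasso zeroes out the noise features
ols = LinearRegression().fit(X[:, selected], y)  # stage 2: plain OLS on the selected features
print(np.flatnonzero(selected), ols.coef_, ols.score(X[:, selected], y))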

+ 170 - 0
Database/database_.py

@@ -0,0 +1,170 @@
+import sys
+sys.path.append("..")
+import pandas as pd
+import pymysql
+from utils.tools import fmt_date
+import config
+
+class DatabaseParam:
+    def __init__(self, db_user: str, db_password: str, db_host: str, db_name: str, db_port: int, db_charset: str='utf8mb4'):
+        self.db_user = db_user
+        self.db_password = db_password
+        self.db_host = db_host
+        self.db_name = db_name
+        self.db_port = db_port
+        self.db_charset = db_charset
+
+    @property
+    def params(self) -> dict:
+        # Placeholder for parameter conversion/validation, to be filled in
+        return {'db_user': self.db_user,
+                'db_password': self.db_password,
+                'db_host': self.db_host,
+                'db_name': self.db_name,
+                'db_port': self.db_port,
+                'db_charset': self.db_charset
+        }
+
+
+class Database:
+    def __init__(self, params: DatabaseParam):
+        self.params = params.params  # connection parameters
+        self.db_conn = None  # connection
+        self.db_cursor = None  # cursor (was `self.cursor`, a name the rest of the class never used)
+
+    def __enter__(self):
+        try:
+            # Stays None if the connection fails
+            self.db_conn = pymysql.connect(host=self.params.get('db_host'),
+                                         user=self.params.get('db_user'),
+                                         password=self.params.get('db_password'),
+                                         database=self.params.get('db_name'),
+                                         port=self.params.get('db_port'),
+                                         charset='utf8mb4')
+            self.db_cursor = self.db_conn.cursor()
+        except pymysql.MySQLError as e:
+            print('Database connection failed:', e)
+            print(f"Please check host: {self.params.get('db_host')}, user: {self.params.get('db_user')}, password: , database: {self.params.get('db_name')}, port: {self.params.get('db_port')}")
+            return None
+        if self.db_cursor and self.db_conn: print(f"Database {self.params.get('db_name')} connected!")
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if self.db_cursor:
+            self.db_cursor.close()
+            self.db_cursor = None
+        if self.db_conn:
+            self.db_conn.close()
+            self.db_conn = None
+        if self.db_cursor is None and self.db_conn is None: print(f"Database {self.params.get('db_name')} disconnected!")
+
+    def sheet_exists(self, sheet_name: str) -> bool:
+        sql = f"""SHOW TABLES FROM {self.params.get('db_name')} LIKE '{sheet_name}'"""
+        self.db_cursor.execute(sql)
+        result = self.db_cursor.fetchall()
+        return len(result) > 0
+    def query_sql_time_series2data_frame(self,
+                              project_id: int,
+                              sheet_name: str,
+                              data_code: str,
+                              start_year: int, end_year: int,
+                              start_month: int = 1, end_month: int = 12,
+                              start_day: int = 1, end_day: int = 31,
+                              start_hour: int = 0, end_hour: int = 23,
+                              start_minute: int = 0, end_minute: int = 59,
+                              start_second: int = 0, end_second: int = 59):
+        # Format the time range
+        start_datetime, end_datetime = fmt_date(start_year=start_year, start_month=start_month, start_day=start_day,
+                 end_year=end_year, end_month=end_month, end_day=end_day,
+                 start_hour=start_hour, start_minute=start_minute, start_second=start_second,
+                 end_hour=end_hour, end_minute=end_minute, end_second=end_second)
+        # Build the query
+        sql = f"""SELECT * FROM {sheet_name} WHERE item_name = '{data_code}' AND project_id = '{project_id}' AND h_time >= '{start_datetime}' AND h_time <= '{end_datetime}'"""
+        if self.db_cursor is None: raise TypeError('Database may not be connected; the cursor must not be None.', self.db_cursor)
+        # Run the query
+        self.db_cursor.execute(sql)
+        result = self.db_cursor.fetchall()
+        df = pd.DataFrame(result, columns=[desc[0] for desc in self.db_cursor.description])
+
+        if not len(df):
+            print(f'Query returned 0 rows for series label: {data_code}')
+            return None
+        if df.iloc[0]['item_name'].strip() != data_code:
+            raise RuntimeError(f"Series name in the database does not match the input. Input: {data_code}, database: {df.iloc[0]['item_name']}")
+        # Do not drop NaN here; cleaning happens downstream
+        # Rename columns
+        val_label = df.iloc[0]['item_name'].strip()
+        df.rename(columns={'val': f'{val_label}', 'h_time': 'time'}, inplace=True)
+        # Drop unrelated columns
+        df.drop(columns=['project_id', 'item_name'], inplace=True, axis=1)
+        # Cast the values to float
+        df[val_label] = df[val_label].astype("float32")
+        return df[['time', val_label]]
+
+    def query_sql_time_series_group2data_frame(self,
+                              code_name_dict: dict,
+                              project_id: int,
+                              sheet_name: str,
+                              data_codes: list,
+                              start_year: int, end_year: int,
+                              start_month: int = 1, end_month: int = 1,
+                              start_day: int = 1, end_day: int = 1,
+                              start_hour: int = 0, end_hour: int = 0,
+                              start_minute: int = 0, end_minute: int = 0,
+                              start_second: int = 0, end_second: int = 0):
+        """Query multiple fields from the database and return one merged result"""
+        frame_list = []
+        data_codes = set(data_codes)
+        for data_code in data_codes:
+            frame = self.query_sql_time_series2data_frame(project_id=project_id,
+                                                  sheet_name=sheet_name,
+                                                  data_code=str(data_code),
+                                                  start_year=start_year, end_year=end_year,
+                                                  start_month=start_month, end_month=end_month,
+                                                  start_day=start_day, end_day=end_day,
+                                                  start_hour=start_hour, end_hour=end_hour,
+                                                  start_minute=start_minute, end_minute=end_minute,
+                                                  start_second=start_second, end_second=end_second,
+                                                )
+            if frame is None: continue
+            # Skip (near-)constant series
+            if frame[frame.columns[1]].nunique() <= 2:
+                print(f'Skipping constant column {frame.columns[1]}')
+                continue
+            frame_list.append(frame)
+        # Merge all fields
+        if len(frame_list) == 0: return None
+
+        df_merge = frame_list[0]
+        for i in range(1, len(frame_list)):
+            df_merge = pd.merge(df_merge, frame_list[i], how='outer', on='time')  # outer join on the timestamp
+        # Sort by date
+        df_merge.sort_values('time', kind='mergesort', inplace=True)
+
+        return df_merge
+
+if __name__ == '__main__':
+    # Build the parameters
+    db_param = DatabaseParam(
+        db_host='192.168.50.4',
+        db_user='root',
+        db_password='*B-@p2b+97D5xAF1e6',
+        db_name='ws_data',
+        db_port=4000)
+
+    # Database operations stay inside the context manager
+    with Database(db_param) as db:
+        df_ = db.query_sql_time_series2data_frame(92,
+                                 'dc_item_history_data_day',
+                                 'QSWGB3_n',
+                                 2025, 2025,
+                                 3, 9,
+                                 25, 10)
+        print(df_)
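
The queries in this file interpolate data_code and the time bounds straight into the SQL text. pymysql can escape values server-side through the second argument of execute(); a sketch of the single-series query in that style (identifiers such as the table name cannot be parameterized, so sheet_name must still come from trusted config):

sql = (f"SELECT * FROM {sheet_name} "
       "WHERE item_name = %s AND project_id = %s "
       "AND h_time >= %s AND h_time <= %s")
self.db_cursor.execute(sql, (data_code, project_id, start_datetime, end_datetime))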

+ 231 - 0
GetItem/get_all_items.py

@@ -0,0 +1,231 @@
+import sys
+sys.path.append("..")
+import config
+import os
+import requests
+import time
+import csv
+from  datetime import datetime
+import shutil
+import json
+
+class DataHelper:
+    """Scrapes the smart-water site to fetch each sensor's database tag and its Chinese display name for a given project.
+    Project code 92: the Xishan Sino-Dutch wastewater reclamation project.
+    """
+    def __init__(self,
+                 project_id = config.PROJECT_ID,
+                 username = config.USERNAME,
+                 password = config.PASSWORD,
+                 dep_id = config.DEP_ID,
+                 base_url = config.BASE_URL,
+                 out_path = config.ALL_ITEMS_FILE_DIR,
+                 out_file_name = config.ALL_ITEMS_FILE_NAME,
+                 save_path_final = config.ALL_ITEMS_FILE_PATH,
+                 max_pages = config.MAX_PAGES,
+                 page_size = config.PAGE_SIZE,
+                 include_head = config.INCLUDE_HEAD
+                 ):
+        print('Fetching all data codes for the project...')
+        self.username = username
+        self.password = password
+        self.dep_id = dep_id
+        self.project_id = project_id
+        self.BASE_URL = base_url  # smart-water site home page
+        self.out_path = out_path
+        self.out_file_name = out_file_name
+        self.max_pages = int(max_pages)
+        self.page_size = int(page_size)
+        self.token = None
+        self.include_head = include_head
+        self.save_path_tem = os.path.join(self.out_path, 'tem_' + self.out_file_name)
+        self.save_path_final = save_path_final
+        self.start_time = time.time()
+        self.end_time = self.start_time
+        # Clean up result files from the previous run
+        if os.path.exists(self.save_path_tem) or os.path.exists(self.save_path_final):
+            print('Cleaning cached files...')
+            if os.path.exists(self.save_path_final):
+                os.remove(self.save_path_final)
+                print(f'Removed  {self.save_path_final}')
+            if os.path.exists(self.save_path_tem):
+                os.remove(self.save_path_tem)
+                print(f'Removed  {self.save_path_tem}')
+
+    def login_smart_water(self):
+        login_url = f"{self.BASE_URL}/api/v2/user/login"  # smart-water login endpoint
+        login_headers = {  # login request headers
+            "Accept": "application/json",
+            "Content-Type": "application/json;charset=utf-8",
+            "Cookie": "lang=zh-CN",
+            "Origin": self.BASE_URL,
+            "Referer": f"{self.BASE_URL}/",
+            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36"
+        }
+        login_params = {  # request parameters
+            "username": self.username,
+            "password": self.password,
+            "type": "account",
+            "DepId": self.dep_id  # department ID
+        }
+        try:
+            # Attempt the login
+            response = requests.post(login_url, json=login_params, headers=login_headers)
+            response.raise_for_status()  # raise on HTTP errors
+            data = response.json()
+            token = data['data']['token']
+            self.token = token if token != '' else None
+            if self.token is not None:
+                print(f'{self.username} logged in! \nGot token {self.token}')
+            else:
+                print(f'{self.username} login failed!')
+
+        except requests.exceptions.HTTPError as errh:
+            print("HTTP Error:", errh)
+        except requests.exceptions.ConnectionError as errc:
+            print("Error Connecting:", errc)
+        except requests.exceptions.Timeout as errt:
+            print("Timeout Error:", errt)
+        except requests.exceptions.RequestException as err:
+            print("Unexpected request error:", err)
+        return None
+
+    @staticmethod
+    def write_file(handler, data: list):
+        write_cnt = 0
+        for label in data:
+            # Columns: name, code, unit, precision, is-enum, device code
+            csv.writer(handler).writerow([label['ItemAlias'], label['ItemName'], label['ItemUnit'], label['ItemPrecise'], int(label['IsBool']), label['DeviceCode']])
+            write_cnt += 1
+        return write_cnt
+
+    @staticmethod
+    def format_chinese_datetime(dt=None):
+        """Format a datetime in the Chinese style used by the output file header"""
+        if dt is None:
+            dt = datetime.now()
+        # Pull out the date/time parts
+        year = dt.year
+        month = dt.month
+        day = dt.day
+        hour = dt.hour
+        minute = dt.minute
+        # Format in Chinese
+        return f"{year}年{month}月{day}日 {hour:02d}:{minute:02d}"
+    def get_all_label(self):
+        if self.token is None:
+            self.login_smart_water()
+        label_url = f"{self.BASE_URL}/api/v1/config/device-realtime-plc-item/list/{self.project_id}"  # data listing endpoint
+
+        headers = {
+            'Accept': '*/*',
+            'Accept-Encoding': 'gzip, deflate',
+            'Accept-Language': 'zh-CN,zh;q=0.9',
+            'Connection': 'keep-alive',
+            'Cookie': 'lang=zh-CN',
+            'Host': '120.55.44.4:8900',
+            'JWT-TOKEN': self.token,
+            'Referer': 'http://120.55.44.4:8900/',
+            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/140.0.0.0 Safari/537.36'
+        }
+        with requests.Session() as session:
+            session.headers.update(headers)
+            failed_cnt = 0
+            # Crawl page by page
+            with open(self.save_path_tem, mode='a', encoding='utf-8', newline='') as file_handler:
+                # Save rows as: name, code, unit, precision, is-enum, device code
+                print('Preparing to write data...')
+                csv.writer(file_handler).writerow(['名称', '编码', '单位', '精度', '是否枚举', '设备号'])
+                pages = 1
+                total_write_cnt = 0
+                while pages <= self.max_pages:
+                    try:
+                        params = {
+                            'currentPage': f'{pages}',
+                            'pageSize': f'{self.page_size}',
+                            'ProjectId': self.project_id,
+                            'time': int(time.time() * 1000)
+                        }
+                        response = session.get(label_url, params=params)
+                        response.raise_for_status()
+                        result = response.json()
+
+                        if result.get('code') == 603:  # token expired; log in again and retry this page
+                            self.login_smart_water()
+                            headers['JWT-TOKEN'] = self.token
+                            session.headers.update(headers)
+                        if result.get('code') == 200:
+                            print(f"Time: {params['time']} page: {params['currentPage']}, page fetched, writing to file")
+                            label_list = result['data']['list']
+                            total_write_cnt += self.write_file(file_handler, label_list)
+                            pages += 1
+
+                    except requests.exceptions.HTTPError as errh:
+                        print("HTTP Error:", errh)
+                        failed_cnt += 1
+                    except requests.exceptions.ConnectionError as errc:
+                        print("Error Connecting:", errc)
+                        failed_cnt += 1
+                    except requests.exceptions.Timeout as errt:
+                        print("Timeout Error:", errt)
+                        failed_cnt += 1
+                    except requests.exceptions.RequestException as err:
+                        print("Unexpected request error:", err)
+                        failed_cnt += 1
+                    finally:
+                        if failed_cnt >= 3:
+                            print('Failed 3 times, giving up!')
+                            break
+                print(f'Write complete: {pages - 1} pages, {total_write_cnt} records!')
+
+            # Write the final file
+            self.end_time = time.time()
+            total_time = round(self.end_time - self.start_time, 2)
+            current_date = self.format_chinese_datetime()
+            stat_info = f"# 项目编号: {self.project_id}, 获取日期: {current_date}, 总记录数量: {total_write_cnt}, 总耗时: {total_time}s"
+            with open(self.save_path_tem, mode='r', encoding='utf-8') as file_handler:
+                with open(self.save_path_final, mode='w', encoding='utf-8', newline='') as final_file_handler:
+                    if self.include_head: final_file_handler.write(stat_info + '\n')
+                    # Copy the temp file's contents into the final file
+                    shutil.copyfileobj(file_handler, final_file_handler)
+            os.unlink(self.save_path_tem)
+            print('all-items file written:', self.save_path_final)
+
+    def get_name_code_transfer(self):
+        """Generate the code<->name transfer file"""
+        total_name_to_code = {'name_2_code': {},
+                              'code_2_name': {},
+                              'len': 0}
+        if not os.path.exists(self.save_path_final):
+            raise RuntimeError('File does not exist:', self.save_path_final)
+        # Write next to the all-items file, which is where pearsonr.py expects to find it
+        file_path_out = os.path.join(self.out_path, config.TRANSFER_JSON_NAME)
+        if os.path.exists(file_path_out):
+            print('Removing old file:', file_path_out)
+            os.remove(file_path_out)
+        with open(self.save_path_final, 'r', encoding='utf-8') as file_handler:
+            csv_reader = csv.reader(file_handler)
+            if self.include_head:
+                try:
+                    next(csv_reader)  # skip the stat-info line
+                except StopIteration:
+                    pass
+            try:
+                next(csv_reader)  # skip the column header row
+            except StopIteration:
+                pass
+
+            for row in csv_reader:
+                total_name_to_code.get('name_2_code').update({row[0].strip(): row[1].strip()})
+                total_name_to_code['len'] += 1
+        total_name_to_code.get('code_2_name').update({v: k for k, v in total_name_to_code.get('name_2_code').items()})
+        with open(file_path_out, 'w', encoding="utf-8", newline='') as f:
+            json.dump(total_name_to_code, f, ensure_ascii=False, indent=4)
+        print('name-code dictionary file written:', file_path_out)
+
+
+if __name__ == '__main__':
+    # Fetch the database field codes and Chinese names from the smart-water site
+    dh = DataHelper()
+    dh.get_all_label()
+    # Generate the code-name dictionary file
+    dh.get_name_code_transfer()
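
session.get() in get_all_label is called without a timeout, so the Timeout handler there can never fire on a connection that simply hangs. A one-line variant with an explicit timeout (the 30-second value is an arbitrary choice, not from the commit):

response = session.get(label_url, params=params, timeout=30)  # seconds; a hung request now raises requests.exceptions.Timeout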

+ 84 - 0
GetItem/get_items_distribution_from_database.py

@@ -0,0 +1,84 @@
+import sys
+sys.path.append("..")
+from Database.database_ import Database, DatabaseParam
+import config
+import os
+import csv
+from utils.tools import fmt_date
+
+
+def add_stat_info():
+    # Read the field file
+    file_path = os.path.join(config.ALL_ITEMS_FILE_DIR, config.ALL_ITEMS_FILE_NAME)
+    if not os.path.exists(file_path):
+        raise RuntimeError('File does not exist ', file_path)
+    # Format the start and end dates
+    start_date, end_date = fmt_date(
+        start_year=config.CHECK_YEAR_START, end_year=config.CHECK_YEAR_END, start_month=config.CHECK_MONTH_START, end_month=config.CHECK_MONTH_END, start_day=config.CHECK_DAY_START, end_day=config.CHECK_DAY_END,
+        start_hour=config.CHECK_HOUR_START, end_hour=config.CHECK_HOUR_END, start_minute=config.CHECK_MINUTE_START, end_minute=config.CHECK_MINUTE_END, start_second=config.CHECK_SECONDS_START, end_second=config.CHECK_SECONDS_END)
+    print(f'Statistics range, start: {start_date}; end: {end_date}')
+    # Connect to the database
+    db_param = DatabaseParam(
+        db_host=config.DB_HOST,
+        db_user=config.DB_USER,
+        db_password=config.DB_PASSWORD,
+        db_name=config.DB_NAME,
+        db_port=config.DB_PORT)
+    # Database operations stay inside the context manager
+    with Database(db_param) as db:   # connect
+        with open(file_path, 'r', encoding='utf-8') as file_handler:
+            csv_reader = csv.reader(file_handler)  # all_items file reader
+            # Read the stat-info line first
+            if config.INCLUDE_HEAD:
+                try:
+                    head = next(csv_reader)
+                except StopIteration:
+                    print(f'{file_path} is empty')
+            # Read the column header row
+            try:
+                label = next(csv_reader)
+            except StopIteration:
+                print(f'{file_path} is empty')
+            # Append the new columns
+            label += ['记录数', '最小时间', '最大时间']
+            stat_file_path = os.path.join(config.STATISTICS_FILE_DIR, config.STATISTICS_FILE_NAME)
+            if os.path.exists(stat_file_path):
+                print('Removing old file:', stat_file_path)
+                os.remove(stat_file_path)
+            with open(stat_file_path, 'w', encoding='utf-8', newline='') as file_handler_stat:
+                # Write the header
+                csv.writer(file_handler_stat).writerow(label)
+                # Fetch every field present in the database
+                sql = f"""SELECT DISTINCT item_name FROM {config.DB_SHEET_NAME} WHERE project_id = {config.PROJECT_ID}"""
+                db.db_cursor.execute(sql)
+                db_items = [item[0].strip() for item in db.db_cursor.fetchall()]
+                # Query the database row by row and write the results
+                print('Computing statistics... ', end='')
+                for num, row in enumerate(csv_reader):  # all_items file reader
+                    data_code = row[1].strip()  # the data code from the table
+                    if data_code not in db_items: continue  # skip fields that are not in the database
+                    # Count the records
+                    sql = f"""SELECT COUNT(*) FROM {config.DB_SHEET_NAME} WHERE item_name = '{data_code}' AND project_id = '{config.PROJECT_ID}' AND h_time >= '{start_date}' AND h_time <= '{end_date}'"""
+                    db.db_cursor.execute(sql)
+                    query_count = db.db_cursor.fetchone()[0]
+                    row.append(query_count)
+                    # Earliest timestamp
+                    sql = f"""SELECT MIN(h_time) FROM {config.DB_SHEET_NAME} WHERE item_name = '{data_code}' AND project_id = '{config.PROJECT_ID}' AND h_time >= '{start_date}' AND h_time <= '{end_date}'"""
+                    db.db_cursor.execute(sql)
+                    query_min_date = db.db_cursor.fetchone()[0]
+                    row.append(query_min_date)
+                    # Latest timestamp
+                    sql = f"""SELECT MAX(h_time) FROM {config.DB_SHEET_NAME} WHERE item_name = '{data_code}' AND project_id = '{config.PROJECT_ID}' AND h_time >= '{start_date}' AND h_time <= '{end_date}'"""
+                    db.db_cursor.execute(sql)
+                    query_max_date = db.db_cursor.fetchone()[0]
+                    row.append(query_max_date)
+                    csv.writer(file_handler_stat).writerow(row)
+                    print('.', end='')
+                print('\nStatistics complete, file saved to:', stat_file_path)
+
+
+
+if __name__ == '__main__':
+    # Count each field's records in the database
+    add_stat_info()
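
add_stat_info scans the same rows three times per field. MySQL can return all three aggregates in a single pass; a sketch of the combined query as a drop-in for the loop body above:

sql = f"""SELECT COUNT(*), MIN(h_time), MAX(h_time) FROM {config.DB_SHEET_NAME}
          WHERE item_name = '{data_code}' AND project_id = '{config.PROJECT_ID}'
          AND h_time >= '{start_date}' AND h_time <= '{end_date}'"""
db.db_cursor.execute(sql)
query_count, query_min_date, query_max_date = db.db_cursor.fetchone()
row += [query_count, query_min_date, query_max_date]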
+

+ 210 - 0
ResultShow/show.py

@@ -0,0 +1,210 @@
+import os.path
+import sys
+sys.path.append("..")
+import pandas as pd
+import config
+import pickle
+from utils.tools import create_custom_heatmap, set_chinese_font, group_list, quick_sort, load_transfer_file_name_code
+import csv
+from sklearn.preprocessing import StandardScaler
+from sklearn.linear_model import LinearRegression
+from sklearn.metrics import mean_squared_error, mean_absolute_error
+from sklearn.metrics import r2_score
+import numpy as np
+from openpyxl import load_workbook
+
+set_chinese_font()
+
+def load_pearsonr_mat():
+    with open(os.path.join(config.R_MAT_JSON_FILE_DIR, config.R_MAT_JSON_FILE_NAME), 'rb') as f:
+        results = pickle.load(f)
+    return results
+
+def show_all_results():
+    """Display all results"""
+    # Load the computed matrix
+    with open(os.path.join(config.R_MAT_JSON_FILE_DIR, config.R_MAT_JSON_FILE_NAME), 'rb') as f:
+        results = pickle.load(f)
+    label_list = results.columns.tolist()
+    # Show heatmaps in row/column groups
+    row_group_elements_num = 25  # rows per group
+    row_group_name = group_list(label_list, row_group_elements_num)
+    col_group_elements_num = 50  # columns per group
+    col_group_name = group_list(label_list, col_group_elements_num)
+
+    for i, row_group in enumerate(row_group_name):
+        for j, col_group in enumerate(col_group_name):
+            corr_matrix = results.loc[row_group, col_group]
+            create_custom_heatmap(corr_matrix, title=f'{config.PROJECT_ID}_plant_data_correlation_heatmap_{i}-{j}')
+
+    query_name = input('Look up where a result is located? (y/n): ')
+    if query_name == 'y' or query_name == 'Y':
+        while query_name != 'exit':
+            query_name = input('Query: ')
+            flag = False
+            for i, row_group in enumerate(row_group_name):
+                if query_name in row_group:
+                    flag = True
+                    print(f'Location: {i}-*.png')
+                    break
+            if not flag:
+                print(f'Location: {query_name} is not in the statistics')
+
+def save_txt(path):
+    """Write the matrix to a txt file in a compact format"""
+    with open(os.path.join(config.R_MAT_JSON_FILE_DIR, config.R_MAT_JSON_FILE_NAME), 'rb') as f:
+        results = pickle.load(f)
+    label_list = results.columns.tolist()
+
+    with open(path, 'w', encoding='utf-8') as f:
+        for i in range(len(label_list)):
+            for j in range(len(label_list)):
+                r = results.iloc[i, j]
+                if abs(r - 1) < 1e-6 or r < 0.2: continue  # skip the diagonal and weak correlations
+                f.write(f'{label_list[i]}-{label_list[j]}:{r:.2f};')
+
+def save_csv(path):
+    """Stub: loads the matrix but does not write anything yet"""
+    with open(os.path.join(config.R_MAT_JSON_FILE_DIR, config.R_MAT_JSON_FILE_NAME), 'rb') as f:
+        results = pickle.load(f)
+    label_list = results.columns.tolist()
+
+def rank(path):
+    """Sort each column by correlation and write the ranking to a file"""
+    if os.path.exists(path):
+        os.remove(path)
+    rmat = load_pearsonr_mat()  # symmetric matrix
+    label_list = rmat.columns.tolist()
+    with open(path, 'w', newline='', encoding='utf-8') as f:
+        csv_writer = csv.writer(f)
+        for col_label in label_list:
+            # Pull one column out of the Pearson matrix
+            elements = []
+            for row_label in label_list:
+                elements.append((row_label, rmat.loc[row_label, col_label]))
+            # Sort ascending by the absolute value of the correlation
+            quick_sort(elements, 0, len(elements) - 1)
+            # Reverse to get descending order
+            elements = elements[::-1]
+            # Write to csv
+            csv_line_content = [f'{tup[0]} | {tup[1]:.2f}' for tup in elements if abs(tup[1]) > 0]
+            csv_writer.writerow([col_label] + csv_line_content)
+
+def directed_heatmap(series_a_name, series_b_name):
+    """Plot a Pearson heatmap for two specific groups of fields"""
+    rmat = load_pearsonr_mat()
+    series_a_name = [i for i in series_a_name if i in rmat.columns.tolist()]
+    series_b_name = [i for i in series_b_name if i in rmat.columns.tolist()]
+    corr_matrix = rmat.loc[series_a_name, series_b_name]
+    title = f"{config.PROJECT_ID}_PearsonMat-{'_'.join(series_a_name[:3])}-VS-{'_'.join(series_b_name[:3])}"
+    create_custom_heatmap(corr_matrix, title=title)
+
+def free_ols(target_name, x_name):
+    """Unconstrained OLS on a user-supplied field set"""
+    # Drop the target from the feature list
+    if target_name in x_name:
+        x_name.remove(target_name)
+    # Load the data
+    with open(config.DF_MERGE_FILE_PATH, 'rb') as f:
+        df_merge_mat = pickle.load(f)
+    name_2_code_dict, code_2_name_dict = load_transfer_file_name_code(os.path.join(config.ALL_ITEMS_FILE_DIR, config.TRANSFER_JSON_NAME))
+
+    if target_name not in name_2_code_dict.keys():
+        raise RuntimeError('The target field does not match the data', target_name)
+
+    x_name = [i for i in x_name if i in name_2_code_dict.keys()]
+
+    target_code = name_2_code_dict.get(target_name)
+    if target_code not in df_merge_mat.columns.tolist():
+        return
+    x_code = [name_2_code_dict.get(i) for i in x_name if name_2_code_dict.get(i) in df_merge_mat.columns.tolist()]
+    if len(x_name) == 0 or len(x_code) == 0:
+        raise RuntimeError('The x fields do not match the data', x_name)
+
+    # OLS with standardized features
+    x = df_merge_mat.loc[:, x_code].copy()
+    y = df_merge_mat.loc[:, target_code]
+    scaler = StandardScaler()
+    x = scaler.fit_transform(x)
+    ols_model = LinearRegression()
+    ols_model.fit(x, y)
+
+    # OLS diagnostics
+    print_info = []
+    print('\n=========== OLS training results ==================')
+    print(f'Y: {target_name}')
+    print(f"OLS intercept: {ols_model.intercept_}")
+    print("OLS coefficients:")
+    for feat, coef in zip(x_name, ols_model.coef_):
+        print(f"  {feat}: {coef:.4f}")
+        print_info.append(f'{coef:.4f}*{feat}')
+    print(f"OLS R² (training set): {ols_model.score(x, y):.4f}")
+    print_info = ['+' + i if i[0] != '-' else i for i in print_info]
+    intercept_str = f'+{ols_model.intercept_:.4}' if ols_model.intercept_ >= 0 else f'{ols_model.intercept_:.4}'
+    print(f"{target_name}=" + ''.join(print_info) + intercept_str)  # the old one-line ternary printed only the intercept when it was negative
+
+    # Basic evaluation metrics
+    y_pred = ols_model.predict(x)
+    residuals = y - y_pred
+    mse = mean_squared_error(y, y_pred)
+    rmse = np.sqrt(mse)
+    mae = mean_absolute_error(y, y_pred)
+    r2 = r2_score(y, y_pred)
+    # Adjusted R²
+    n = len(y)
+    p = x.shape[1]
+    adj_r2 = 1 - (1 - r2) * (n - 1) / (n - p - 1)
+    print("\n=========== Model performance metrics ==================")
+    print(f"MSE: {mse:.4f}")
+    print(f"RMSE: {rmse:.4f}")
+    print(f"MAE: {mae:.4f}")
+    print(f"R²: {r2:.4f}")
+    print(f"Adjusted R²: {adj_r2:.4f}")
+
+if __name__ == '__main__':
+    # Display all results
+    # show_all_results()
+    # Write the matrix to txt
+    # save_txt('./tem.txt')
+    # Pearson ranking
+    # rank(f'./{config.PROJECT_ID}_rank.csv')
+
+    # Targeted Pearson heatmaps
+    # Load the Excel workbook
+    # workbook = load_workbook(f'./{config.PROJECT_ID}_field_combination.xlsx')
+    # # List all sheet names
+    # sheet_names = workbook.sheetnames
+    # print("Sheets in the file:", sheet_names)
+    # # Walk every sheet
+    # for sheet_name in sheet_names:
+    #     sheet = workbook[sheet_name]
+    #     # Column A values (from the first row)
+    #     series_a_name = [cell.value for cell in sheet['A'] if cell.value is not None]
+    #     # Column B values (from the first row)
+    #     series_b_name = [cell.value for cell in sheet['B'] if cell.value is not None]
+    #     print(f"Sheet name: {sheet_name}")
+    #     print(f"  Column A {series_a_name}  ")
+    #     print(f"  Column B {series_b_name}  ")
+    #     directed_heatmap(series_a_name, series_b_name)
+
+    # Targeted free-form regression
+    # Load the Excel workbook
+    workbook = load_workbook(f'./{config.PROJECT_ID}_field_ols.xlsx')
+    # List all sheet names
+    sheet_names = workbook.sheetnames
+    print("Sheets in the file:", sheet_names)
+    # Walk every sheet
+    for sheet_name in sheet_names:
+        sheet = workbook[sheet_name]
+        # Column A: the feature names (from the first row)
+        series_a_name = [cell.value for cell in sheet['A'] if cell.value is not None]
+        # Column B: the target name in the first cell
+        series_b_name = [cell.value for cell in sheet['B'] if cell.value is not None]
+        free_ols(series_b_name[0], series_a_name)
+
+
+
+
+
+
+
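free_ols expects each sheet to hold the feature names in column A and the target name in the first cell of column B. A sketch that builds such a workbook with openpyxl (the feature names are illustrative placeholders; 1420 matches PROJECT_ID in this commit's config):

from openpyxl import Workbook

wb = Workbook()
ws = wb.active
for i, feature in enumerate(['进水流量', '进水压力'], start=1):  # column A: feature names (placeholders)
    ws.cell(row=i, column=1, value=feature)
ws.cell(row=1, column=2, value='RO1脱盐率')  # column B, first cell: the target
wb.save('./1420_field_ols.xlsx')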

+ 102 - 0
config.py

@@ -0,0 +1,102 @@
+import os
+
+# Project ID
+PROJECT_ID = 1420  # TODO: change to the target project ID
+# smart-water site login username and password
+USERNAME = 'admin'
+PASSWORD = 'JK20200508'
+# smart-water login department-ID parameter
+DEP_ID = '135'
+# smart-water site address
+BASE_URL = 'http://120.55.44.4:8900'
+# ALL ITEMS file output location
+ALL_ITEMS_FILE_DIR = r'D:\code\data_analysis\GetItem'
+ALL_ITEMS_FILE_NAME = f'{PROJECT_ID}_all_items.csv'  # output file name; do not change lightly
+ALL_ITEMS_FILE_PATH = os.path.join(ALL_ITEMS_FILE_DIR, ALL_ITEMS_FILE_NAME)
+INCLUDE_HEAD = True  # whether the ALL_ITEMS file carries a header info line
+TRANSFER_JSON_NAME = f'{PROJECT_ID}_all_items_name_code_transfer.json'
+# Maximum number of listing pages on the smart-water site
+MAX_PAGES = 290  # TODO: set to the actual page count
+PAGE_SIZE = 20
+
+# Database parameters, used when pulling data from the database
+DB_HOST = '192.168.50.4'
+DB_USER = 'root'
+DB_PASSWORD = '*B-@p2b+97D5xAF1e6'
+DB_NAME = 'ws_data'  # database name
+DB_PORT = 4000
+POSTFIX = 'hour'  # must match the suffix of DB_SHEET_NAME; TODO: set to day/hour/minute as needed
+DB_SHEET_NAME = f'dc_item_history_data_{POSTFIX}'  # table name
+# Query start date
+CHECK_YEAR_START = 2025
+CHECK_MONTH_START = 1
+CHECK_DAY_START = 1
+# Query end date
+CHECK_YEAR_END = 2025
+CHECK_MONTH_END = 12
+CHECK_DAY_END = 31
+# Time of day (hh:mm:ss)
+CHECK_HOUR_START = 0
+CHECK_MINUTE_START = 0
+CHECK_SECONDS_START = 0
+CHECK_HOUR_END = 23
+CHECK_MINUTE_END = 59
+CHECK_SECONDS_END = 59
+
+# Record-count statistics
+STATISTICS_FILE_DIR = r'D:\code\data_analysis\GetItem'
+STATISTICS_FILE_NAME = f'{PROJECT_ID}_statistics_{POSTFIX}.csv'
+# Whether to load from local files instead of querying the database
+IS_FROM_LOCAL = False
+
+# Pearson correlation settings
+EXCLUDE_WORDS = ['电流', '控制字', '步序', '时间设定', '开关', '报警', '噪音']  # exclusion list: fields containing any of these words are skipped
+DIFF_WORDS = ['累计', '计数', '运行时间', '电能']  # differencing list: fields containing any of these words are first-differenced
+MIN_RECORDS = 2000 # fields with fewer records than this are excluded  TODO: adjust the minimum record count
+IS_LAG = True
+IS_NORMALIZE = False
+MAX_LAG = 0  # maximum lag; 0 means no lag
+STEP = 1  # lag step size
+P_VALUE_THRESHOLD = 0.05 # significance p-value threshold
+PEARSONR_VALUE_THRESHOLD = 0.10 # Pearson r threshold; values below it are treated as 0
+# Where the Pearson output files are saved
+R_MAT_JSON_FILE_DIR = r'D:\code\data_analysis\Analysis'
+R_MAT_JSON_FILE_NAME = f'{PROJECT_ID}_pearsonr_mat_{POSTFIX}.pkl'
+R_MAT_JSON_PATH = os.path.join(R_MAT_JSON_FILE_DIR, R_MAT_JSON_FILE_NAME)
+
+
+# Lasso
+LASSO_FEATURE_FILE_DIR = r'D:\code\data_analysis\Analysis'
+LASSO_FEATURE_FILE_NAME = 'lasso_features_choose.json'
+LASSO_FEATURE_FILE_PATH = os.path.join(R_MAT_JSON_FILE_DIR, LASSO_FEATURE_FILE_NAME)
+
+
+# Other derived parameters
+def fmt_date(start_year,end_year,
+             start_month,end_month,
+             start_day,end_day,
+             start_hour=0,end_hour=23,
+             start_minute=0,end_minute=59,
+             start_second=0,end_second=59):
+    fmt = lambda x: '0' + str(x) if abs(x) < 10 else str(x)
+    start_month = fmt(start_month)
+    end_month = fmt(end_month)
+    start_day = fmt(start_day)
+    end_day = fmt(end_day)
+    start_hour = fmt(start_hour)
+    end_hour = fmt(end_hour)
+    start_minute = fmt(start_minute)
+    end_minute = fmt(end_minute)
+    start_second = fmt(start_second)
+    end_second = fmt(end_second)
+
+    start_datetime = f'{start_year}-{start_month}-{start_day} {start_hour}:{start_minute}:{start_second}'
+    end_datetime = f'{end_year}-{end_month}-{end_day} {end_hour}:{end_minute}:{end_second}'
+    return start_datetime, end_datetime
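+# Example: fmt_date(2025, 2025, 1, 12, 1, 31) -> ('2025-01-01 00:00:00', '2025-12-31 23:59:59')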
+DATE_START, DATE_END = fmt_date(start_year=CHECK_YEAR_START, start_month=CHECK_MONTH_START, start_day=CHECK_DAY_START,
+                                start_hour=CHECK_HOUR_START, start_minute=CHECK_MINUTE_START, start_second=CHECK_SECONDS_START,
+                                end_year=CHECK_YEAR_END, end_month=CHECK_MONTH_END, end_day=CHECK_DAY_END,
+                                end_hour=CHECK_HOUR_END, end_minute=CHECK_MINUTE_END, end_second=CHECK_SECONDS_END)
+DF_MERGE_FILE_DIR = r'D:\code\data_analysis\Analysis'
+DF_MERGE_FILE_PATH = os.path.join(DF_MERGE_FILE_DIR, f'{PROJECT_ID}_' + DB_NAME + '_' + DB_SHEET_NAME + '_' + POSTFIX + '_' + DATE_START.replace(':','-') + '_' + DATE_END.replace(':','-') + '.pkl').replace(' ', '_')
+

+ 12 - 0
readme

@@ -0,0 +1,12 @@
+Step 0. Configure config.py:
+    Set the water-plant project ID, the database table name and suffix (day/hour/minutes), the time range, and the other key parameters marked with TODO.
+Step 1. Fetch all fields of the plant:
+    Run get_all_items.py. It produces all_items.csv, which records the field names and codes in the database, plus the name/code translation dictionary all_items_name_code_transfer.json.
+Step 2. Profile the database records:
+    Run get_items_distribution_from_database.py. It produces statistics.csv, which records the number of data points per field; use it to tune MIN_RECORDS in config.
+Step 3. Compute Pearson results for all fields:
+    Run pearsonr.py to compute Pearson coefficients across all fields. Significance p-values are applied during the computation, so only significant results are kept. pearsonr_mat.pkl is the Pearson matrix; ws_data_dc_item_history_data_hour.pkl holds the history data.
+Step 4. Visualize as heatmaps:
+    Run show.py.
+Step 5. Pick correlated variables and run regression analysis.

+ 21 - 0
temp/config_analysis.py

@@ -0,0 +1,21 @@
+PROJECT_ID = 92
+# Needs to be set per project
+BASE_NAME = f'{PROJECT_ID}_all_items_statistics_hour'
+INPUT_CSV_FILE = f'../GetItem/{BASE_NAME}.csv'  # input statistics file: field names and record counts
+TOTAL_LIST_JSON_FILE = f'{BASE_NAME}.json'  # filtered fields whose correlations will be computed
+OUTPUT_JSON_FILE = f'../Analysis/{BASE_NAME}_out.json'  # computed results
+COLUMN_NAME_2_INDEX = {'名称': 0, '编码': 1, '单位': 2, '精度': 3, '设备号': 4,'是否枚举':5, '记录数': 6, '最小时间': 7, '最大时间': 8}
+# Needs to be set per project
+DB_SHEET_NAME = 'dc_item_history_data_hour'  # table to query
+# Needs to be set per project
+DATA_MIN_RECORDS = 400 # fields with fewer records than this are excluded
+MAX_LAG = 2
+CHECK_YEAR_START = 2025
+CHECK_YEAR_END = 2025
+CHECK_MONTH_START = 6
+CHECK_MONTH_END = 9
+CHECK_DAY_START = 10
+CHECK_DAY_END = 10
+
+P_VALUE_THRESHOLD = 0.05 # p-value threshold for the significance test
+R_THRESHOLD = 0.35 # Pearson r threshold; above 0.6 counts as strong correlation, above 0.8 as very strong

+ 145 - 0
temp/data_show.py

@@ -0,0 +1,145 @@
+#sys.path.append("..")
+import json
+import os
+import csv
+from temp import config_analysis
+from temp.utils_analysis import create_custom_heatmap, set_chinese_font, cross_corr, group_list
+from Database.database_ import Database, DatabaseParam
+
+
+def read_json(json_file, key:str='data'):
+    """加载json"""
+    with open(json_file, 'r', encoding='utf-8') as f:
+        data_ = json.load(f)
+        print('数据加载成功,总数量:', data_.get('len'))
+    return data_.get(key)
+
+def select(d_list: list, l_list:list) -> list:
+    """Drop pairs whose fields only ever correlate with themselves (l_list is currently unused)"""
+    counter_dict_row = {}
+    counter_dict_col = {}
+    # Initialize the row/column counters
+    # for l in l_list:
+    #     counter_dict_row.update({l: 0})
+    #     counter_dict_col.update({l: 0})
+    for d in d_list:
+        counter_dict_row.update({d.get('A').get('name'): 0})
+        counter_dict_col.update({d.get('B').get('name'): 0})
+    # Count occurrences per row and column
+    for d in d_list:
+        counter_dict_row[d['A']['name']] += 1
+        counter_dict_col[d['B']['name']] += 1
+
+    # Drop fields that are only self-correlated
+    new_d_list_idx = []
+    for idx, d in enumerate(d_list):
+        if d['A']['name'] == d['B']['name']:
+            if counter_dict_row[d['A']['name']] == 1 and counter_dict_col[d['B']['name']] == 1:
+                continue
+        new_d_list_idx.append(idx)
+    new_d_list = [d_list[i] for i in new_d_list_idx]
+    return new_d_list
+
+
+if __name__ == '__main__':
+    # Manually added fields
+    added_list = ['超滤总产水浊度','超滤总产水余氯']
+    # Exclusion list
+    not_selected = ['电流', '控制字', '步序', '时间设定', '开关', '报警', '噪音']
+    # Set a Chinese font for matplotlib
+    set_chinese_font()
+    # Load the name-to-code mapping
+    with open('../GetItem/92_all_items_name_code_transfer.json', 'r', encoding='utf-8') as f:  # master field file
+        name_code_transfer = json.load(f)
+        print(f'加载name与code映射字典,共{name_code_transfer.get("len")}条')
+    name_2_code = name_code_transfer.get('name_2_code')
+    code_2_name = name_code_transfer.get('code_2_name')
+    del name_code_transfer
+    # Tabulate the results
+    out_name = 'result_' + os.path.basename(config_analysis.OUTPUT_JSON_FILE)[:-5] + '.csv'
+    data_list = read_json(config_analysis.OUTPUT_JSON_FILE)
+    # All fields included in the statistics
+    total_name = read_json(config_analysis.OUTPUT_JSON_FILE[:-9] + '.json', 'total_name_list')
+    total_code = read_json(config_analysis.OUTPUT_JSON_FILE[:-9] + '.json', 'total_code_list')
+    data_list = select(data_list, total_name)
+    # Write to CSV (newline='' avoids blank rows on Windows)
+    if os.path.exists(out_name):
+        os.remove(out_name)
+    with open(out_name, 'a', encoding='utf-8', newline='') as f:
+        csv.writer(f).writerow(['A序列', 'B序列', 'k1', 'p1', 'r1', 'k2', 'p2', 'r2', 'k3', 'p3', 'r3'])
+        for item_dict in data_list:
+            txt_content_A = [f"{item_dict.get('A').get('name')}({item_dict.get('A').get('code')})"]
+            txt_content_B = [f"{item_dict.get('B').get('name')}({item_dict.get('B').get('code')})"]
+            txt_content_res = []
+            for it in item_dict.get('res'):
+                txt_content_res.append(f"{it.get('k')}")
+                txt_content_res.append(f"{it.get('p'):.4f}")
+                txt_content_res.append(f"{it.get('r'):.4f}")
+            csv.writer(f).writerow(txt_content_A + txt_content_B + txt_content_res)
+    new_label_name = set()
+    for d in data_list:
+        new_label_name.add(d['A']['name'])
+        new_label_name.add(d['B']['name'])
+    # Apply the manually added fields
+    for ele in added_list:
+        new_label_name.add(ele)
+    # Apply the exclusion list
+    new_label_name_temp = []
+    for ele in new_label_name:
+        flag = True
+        for not_selected_ele in not_selected:
+            if not_selected_ele in ele:
+                flag = False
+                break
+        if flag: new_label_name_temp.append(ele)
+    new_label_name = new_label_name_temp
+    del new_label_name_temp
+    new_label_code = [name_2_code.get(ele) for ele in new_label_name]
+    # Fields remaining after the second round of filtering
+    del data_list
+    print('筛选后还剩下的字段数:', len(new_label_name))
+
+    # Second-pass computation, group by group
+    db_param = DatabaseParam(
+        db_host='192.168.50.4',
+        db_user='root',
+        db_password='*B-@p2b+97D5xAF1e6',
+        db_name='ws_data',
+        db_port=4000)
+    # Compute per group
+    with Database(db_param) as db:
+        group_df = db.query_sql_time_series_group2data_frame(
+                                                 code_name_dict=code_2_name,
+                                                 project_id=config_analysis.PROJECT_ID,
+                                                 sheet_name=config_analysis.DB_SHEET_NAME,
+                                                 data_codes=new_label_code,
+                                                 start_year=config_analysis.CHECK_YEAR_START,
+                                                 end_year=config_analysis.CHECK_YEAR_END,
+                                                 start_month=config_analysis.CHECK_MONTH_START,
+                                                 end_month=config_analysis.CHECK_MONTH_END,
+                                                 start_day=config_analysis.CHECK_DAY_START,
+                                                 end_day=config_analysis.CHECK_DAY_END)
+        del new_label_code, new_label_name
+        new_label_code = sorted([i for i in group_df.columns.tolist() if i != 'time'])
+        # Split the field list into groups
+        row_group_elements_num = 25  # rows per group
+        row_group_code = group_list(new_label_code, row_group_elements_num)
+        col_group_elements_num = 50  # columns per group
+        col_group_code = group_list(new_label_code, col_group_elements_num)
+        for i, row_code in enumerate(row_group_code):
+            for j, col_code in enumerate(col_group_code):
+                corr = cross_corr(row_code, col_code, group_df, code_2_name)
+                create_custom_heatmap(corr_matrix=corr, title=f'中荷水厂数据相关系数热力图{i}-{j}')
+
+    query_name = input('是否继续查询结果所在位置(y/n):')
+    if query_name == 'y' or query_name == 'Y':
+        while query_name != '退出':
+            query_name = input('查询:')
+            flag = False
+            for i, row_group in enumerate(row_group_code):
+                if name_2_code.get(query_name) in row_group:
+                    flag = True
+                    print(f'位置:{i}-*.png')
+                    break
+            if not flag:
+                print(f'位置:{query_name}不在统计范围')

+ 79 - 0
temp/directed_show.py

@@ -0,0 +1,79 @@
+"""
+指定两组数据标签,绘制相关系数热力图,用于定向分析
+"""
+import sys
+sys.path.append("..")
+from temp import config_analysis
+from temp.utils_analysis import create_custom_heatmap, set_chinese_font, cross_corr
+from Database.database_ import Database, DatabaseParam
+import json
+import matplotlib.pyplot as plt
+from datetime import datetime
+time_now = datetime.now().strftime('%H:%M:%S')
+#series_a_name = ['管廊间温度1','加药间温度1','膜车间温度1','泵房温度1']   # column labels, edit as needed
+#series_b_name = ['RO1 1段酸洗记次数','RO1 1段J洗记次数', 'RO1 1段FY洗记次数', 'RO1 2段FY洗记次数', '1#超滤反洗水泵运行时间', '段间泵A温度', '反渗透高压泵A温度', '反渗透冲洗水泵A温度','反渗透外供水泵A温度','超滤反洗水泵A温度','超滤供水泵A温度','超滤总进水浊度','超滤反洗泵A 累计电量'] # row labels, edit as needed
+#series_a_name = ['UF1跨膜压差','UF2跨膜压差', 'UF3跨膜压差', 'UF4跨膜压差','超滤总产水PH', 'RO1产水电导','RO2产水电导','RO3产水电导','RO4产水电导','反渗透总产水电导','RO1一段浓水压力','RO2一段浓水压力','RO3一段浓水压力','RO4一段浓水压力','RO1一段进水压力','RO2一段进水压力','RO3一段进水压力','RO4一段进水压力']
+#series_b_name = ['超滤总进水压力','超滤总产水压力','超滤总进水浊度', '反渗透总进水温度','反渗透总进水电导','超滤总产水PH','反渗透总进水PH','外供水-PH','1#阻垢剂加药泵(RO)运行频率','2#阻垢剂加药泵(RO)运行频率','3#阻垢剂加药泵(RO)运行频率', '4#阻垢剂加药泵(RO)运行频率','1#反渗透阻垢剂流量','2#反渗透阻垢剂流量','3#反渗透阻垢剂流量','4#反渗透阻垢剂流量','还原剂流量','还原剂药箱1液位','还原剂药箱2液位','盐酸药箱液位','RO1一段浓水压力','RO2一段浓水压力','RO3一段浓水压力','RO4一段浓水压力','RO1一段进水压力','RO2一段进水压力','RO3一段进水压力','RO4一段进水压力']
+series_a_name = ['超滤总进水压力','RO1产水电导','RO2产水电导','RO3产水电导','RO4产水电导','超滤总产水PH','反渗透总产水电导']
+series_b_name = ['RO1一段浓水压力','RO2一段浓水压力','RO3一段浓水压力','RO4一段浓水压力','RO1一段进水压力','RO2一段进水压力','RO3一段进水压力','RO4一段进水压力','超滤总进水压力','反渗透总进水温度']
+#series_a_name = ['RO1产水电导','RO2产水电导','RO3产水电导','RO4产水电导']
+#series_b_name = ['1#阻垢剂加药泵(RO)运行频率','2#阻垢剂加药泵(RO)运行频率','3#阻垢剂加药泵(RO)运行频率', '4#阻垢剂加药泵(RO)运行频率','1#反渗透阻垢剂流量','2#反渗透阻垢剂流量','3#反渗透阻垢剂流量','4#反渗透阻垢剂流量']
+db_param = DatabaseParam(
+        db_host='192.168.50.4',
+        db_user='root',
+        db_password='*B-@p2b+97D5xAF1e6',
+        db_name='ws_data',
+        db_port=4000)
+# Validate the requested fields
+with open('../GetItem/92_all_items_name_code_transfer.json', 'r', encoding='utf-8') as f:  # master field file
+    name_code_transfer = json.load(f)
+    print(f'加载name与code映射字典,共{name_code_transfer.get("len")}条')
+name_2_code = name_code_transfer.get('name_2_code')
+code_2_name = name_code_transfer.get('code_2_name')
+del name_code_transfer
+
+for name in series_a_name:
+    if name not in name_2_code.keys():
+        raise ValueError(f'指定字段{name}不存在', '输入A序列:', series_a_name)
+for name in series_b_name:
+    if name not in name_2_code.keys():
+        raise ValueError(f'指定字段{name}不存在', '输入B序列:', series_b_name)
+# Set a Chinese font for matplotlib
+set_chinese_font()
+# Compute per group
+with Database(db_param) as db:
+    series_a_code = [name_2_code.get(i) for i in series_a_name]
+    series_b_code = [name_2_code.get(i) for i in series_b_name]
+    group_df = db.query_sql_time_series_group2data_frame(
+        code_name_dict=code_2_name,
+        project_id=config_analysis.PROJECT_ID,
+        sheet_name=config_analysis.DB_SHEET_NAME,
+        data_codes=series_a_code + series_b_code,
+        start_year=config_analysis.CHECK_YEAR_START,
+        end_year=config_analysis.CHECK_YEAR_END,
+        start_month=config_analysis.CHECK_MONTH_START,
+        end_month=config_analysis.CHECK_MONTH_END,
+        start_day=config_analysis.CHECK_DAY_START,
+        end_day=config_analysis.CHECK_DAY_END)
+    # Drop fields missing from the query result
+    series_a_code = [i for i in series_a_code if i in group_df.columns.tolist()]
+    series_b_code = [i for i in series_b_code if i in group_df.columns.tolist()]
+
+    corr = cross_corr(series_a_code, series_b_code, group_df, code_2_name)
+    create_custom_heatmap(corr_matrix=corr, title=f'中荷水厂数据相关系数热力图 ' + time_now)
+
+    # Plot a time-series curve for each field
+    for code in series_a_code+series_b_code:
+
+        plt.figure(figsize=(12, 6))
+        # Raw data
+        plt.plot(group_df['time'], group_df[code], 'b-', alpha=0.3, label='原始数据')
+        plt.title(f'{code_2_name.get(code)} 时间序列')
+        plt.xlabel('时间')
+        plt.ylabel(code_2_name.get(code))
+        plt.legend()
+        plt.grid(True)
+        plt.xticks(rotation=45)  # rotate x labels for readability
+        plt.savefig(f'{code_2_name.get(code)}.png', bbox_inches='tight', dpi=300)
+        plt.close()
+        print(f"已保存 {code_2_name.get(code)} 的曲线图到 {code_2_name.get(code)}.png")

+ 189 - 0
temp/main_analysis.py

@@ -0,0 +1,189 @@
+import sys
+sys.path.append("..")
+from Database.database_ import Database, DatabaseParam
+import pandas as pd
+from scipy import stats
+from utils_analysis import label_queue, diff_tool, skip_tool
+import config_analysis
+import os
+import json
+import time
+
+# Print a confirmation summary before running
+print(f"""
+查询数据库:ws_data
+查询表:{config_analysis.DB_SHEET_NAME}
+起始日期:{config_analysis.CHECK_YEAR_START}-{config_analysis.CHECK_MONTH_START}-{config_analysis.CHECK_DAY_START}
+终止日期:{config_analysis.CHECK_YEAR_END}-{config_analysis.CHECK_MONTH_END}-{config_analysis.CHECK_DAY_END}
+项目ID:{config_analysis.PROJECT_ID}
+""")
+time.sleep(6)  # pause so the operator can abort if the settings above are wrong
+# Database connection parameters
+db_param = DatabaseParam(
+    db_host='192.168.50.4',
+    db_user='root',
+    db_password='*B-@p2b+97D5xAF1e6',
+    db_name='ws_data',
+    db_port=4000)
+
+# Master lists of fields that pass the filters
+total_name_list = []
+total_code_list = []
+
+
+# All database work stays inside the context manager; Database releases the connection and cursor automatically
+with Database(db_param) as db:
+    # Exclude constant series
+    # Load the candidate list from file when it exists
+    if os.path.exists(config_analysis.TOTAL_LIST_JSON_FILE):
+        with open(config_analysis.TOTAL_LIST_JSON_FILE, "r", encoding="utf-8") as f:
+            loaded_data = json.load(f)
+            print(f'从文件{config_analysis.TOTAL_LIST_JSON_FILE}中加载待分析列表...')
+        total_name_list = loaded_data['total_name_list']
+        total_code_list = loaded_data['total_code_list']
+    # Otherwise build the candidate list by querying the database now
+    else:
+        for lab in label_queue():
+            time_series_name = lab.get('name')
+            time_series_code = lab.get('code')
+            df = db.query_sql_time_series2data_frame(project_id=config_analysis.PROJECT_ID,
+                                                     sheet_name=config_analysis.DB_SHEET_NAME,
+                                                     data_code=time_series_code,
+                                                     start_year=config_analysis.CHECK_YEAR_START, end_year = config_analysis.CHECK_YEAR_END,
+                                                     start_month=config_analysis.CHECK_MONTH_START, end_month=config_analysis.CHECK_MONTH_END,
+                                                     start_day=config_analysis.CHECK_DAY_START, end_day=config_analysis.CHECK_DAY_END)
+            if df is None:
+                continue
+            # Filter constant series
+            if df[df.columns[1]].nunique() <= 2:
+                print(f'过滤常数序列{time_series_name}({time_series_code})!')
+                continue
+            else:
+                total_name_list.append(time_series_name)
+                total_code_list.append(time_series_code)
+        # Save the candidate list
+        saved_data = {
+            'total_name_list': total_name_list,
+            'total_code_list': total_code_list,
+        }
+
+        with open(config_analysis.TOTAL_LIST_JSON_FILE, "w", encoding="utf-8") as f:
+            json.dump(saved_data, f, ensure_ascii=False, indent=4)
+            print(f'分析列表保存到{config_analysis.TOTAL_LIST_JSON_FILE}')
+    # Store every computed result
+    result = []
+    """
+    result: [dict, dict, ...]
+    dict format:
+    {
+    'A': {'name': ..., 'code': ...},
+    'B': {'name': ..., 'code': ...},
+    'res': [{'k': lag, 'r': value, 'p': value}, ...]
+    }
+    """
+    # Iterate over the fields to analyze, as read from the candidate list
+    # Series A
+    for a_idx in range(0, len(total_code_list), 1):
+
+        time_series_a_name = total_name_list[a_idx]
+        time_series_a_code = total_code_list[a_idx]
+        # Fetch series A
+        df_a = db.query_sql_time_series2data_frame(project_id=config_analysis.PROJECT_ID,
+                                                   sheet_name=config_analysis.DB_SHEET_NAME,
+                                                   data_code=time_series_a_code,
+                                                   start_year=config_analysis.CHECK_YEAR_START,
+                                                   end_year=config_analysis.CHECK_YEAR_END,
+                                                   start_month=config_analysis.CHECK_MONTH_START,
+                                                   end_month=config_analysis.CHECK_MONTH_END,
+                                                   start_day=config_analysis.CHECK_DAY_START,
+                                                   end_day=config_analysis.CHECK_DAY_END)
+        if df_a is None:
+            continue
+        # Filter constant series
+        if df_a[df_a.columns[1]].nunique() <= 2:
+            print(f'序列A.{time_series_a_name}({time_series_a_code})遇到常数列, 跳过计算!')
+            continue
+        # Stationarize (first-difference cumulative-style series)
+        df_a = diff_tool(time_series_a_name, df_a, df_a.columns[1])
+        # Series B
+        for b_idx in range(a_idx, len(total_code_list), 1):
+            time_series_b_name = total_name_list[b_idx]
+            time_series_b_code = total_code_list[b_idx]
+            if skip_tool(time_series_a_name, time_series_b_name):
+                print(f'跳过组合:{time_series_a_name} vs. {time_series_b_name}')
+                continue
+            # Fetch series B
+            df_b = db.query_sql_time_series2data_frame(project_id=config_analysis.PROJECT_ID,
+                                                       sheet_name=config_analysis.DB_SHEET_NAME,
+                                                       data_code=time_series_b_code,
+                                                       start_year=config_analysis.CHECK_YEAR_START,
+                                                       end_year=config_analysis.CHECK_YEAR_END,
+                                                       start_month=config_analysis.CHECK_MONTH_START,
+                                                       end_month=config_analysis.CHECK_MONTH_END,
+                                                       start_day=config_analysis.CHECK_DAY_START,
+                                                       end_day=config_analysis.CHECK_DAY_END)
+            if df_b is None:
+                continue
+            # if abs(len(df_a) - len(df_b)) > 20: raise ValueError('时序数据数量差异过大:len(A), len(B)', len(df_a),
+            #                                                      len(df_b))
+            # Filter constant series: near-constant columns have almost zero variance, so the covariance is meaningless
+            if df_b[df_b.columns[1]].nunique() <= 2:
+                print(f'序列B.{time_series_b_name}({time_series_b_code})遇到常数列, 跳过计算!')
+                continue
+            # Stationarize: first-difference series whose names match the differencing list
+            df_b = diff_tool(time_series_b_name, df_b, df_b.columns[1])
+            # Merge series A and B on time
+            df_merge = pd.merge(df_a, df_b, how='inner', on='time').sort_values('time', kind='mergesort')
+            _, time_series_a_column, time_series_b_column = df_merge.columns
+
+            # Cross-correlation analysis
+            series_a = df_merge[time_series_a_column]
+            series_b = df_merge[time_series_b_column]
+
+            lags = config_analysis.MAX_LAG  # maximum lag
+            step = 1
+            print(f'正在进行互相关性分析:A.{time_series_a_name}({time_series_a_code}) | B.{time_series_b_name}({time_series_b_code}) ')
+            tem_dict = {'A': {'name': time_series_a_name, 'code': time_series_a_code},
+                        'B': {'name': time_series_b_name, 'code': time_series_b_code},
+                        'res':[]}
+            for lag in range(-lags, lags + 1, step):  # +1 keeps the lag range symmetric (-lags, ..., 0, ..., +lags)
+                if lag < 0:  # a lags b
+                    series_a_shifted = series_a[-lag:]
+                    series_b_shifted = series_b[:lag]
+                elif lag > 0:  # b lags a
+                    series_a_shifted = series_a[:-lag]
+                    series_b_shifted = series_b[lag:]
+                else:  # no lag
+                    series_a_shifted = series_a
+                    series_b_shifted = series_b
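+                # Alignment sketch: with lag = +1, a = [a0, a1, a2] and b = [b0, b1, b2]
+                # are compared as the pairs (a0, b1), (a1, b2): each a value is paired
+                # with the next b value.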
+                # Pearson coefficient and significance
+                if len(series_a_shifted) < 24 or len(series_b_shifted) < 24:
+                    print('skip: overlapping window too short')
+                    continue
+                r, p_value = stats.pearsonr(series_a_shifted, series_b_shifted)
+                # Drop results that fail the significance test
+                if p_value > config_analysis.P_VALUE_THRESHOLD:
+                    continue
+                if abs(r) < config_analysis.R_THRESHOLD:
+                    continue
+
+                tem_dict.get('res').append({'k':lag, 'r':r, 'p':p_value})
+                # if lag < 0:
+                #     print(f'A滞后B {abs(lag)}个单位时间, k={lag}, r={r:.4f}, 显著性p={p_value:.4f}')
+                # elif lag > 0:
+                #     print(f'B滞后A {abs(lag)}个单位时间, k={lag}, r={r:.4f}, 显著性p={p_value:.4f}')
+                # else:
+                #     print(f'A与B无滞后, k={lag}, r={r:.4f}, 显著性p={p_value:.6f}')
+            if tem_dict.get('res'): result.append(tem_dict)
+    print(f'计算完成,结果总数量为:{len(result)}')
+
+# Save the results to file
+if os.path.exists(config_analysis.OUTPUT_JSON_FILE):
+    print(f'删除旧文件{config_analysis.OUTPUT_JSON_FILE}')
+    os.remove(config_analysis.OUTPUT_JSON_FILE)
+data = {'data': result, 'len': len(result), 'r_threshold': config_analysis.R_THRESHOLD, 'p_threshold': config_analysis.P_VALUE_THRESHOLD}
+with open(config_analysis.OUTPUT_JSON_FILE, 'w', encoding="utf-8") as f:
+    json.dump(data, f, ensure_ascii=False, indent=4)
+    print(f'数据保存完成,{config_analysis.OUTPUT_JSON_FILE}')
+
+

+ 164 - 0
temp/utils_analysis.py

@@ -0,0 +1,164 @@
+import os
+import sys
+sys.path.append("..")
+import csv
+from temp import config_analysis
+from temp.config_analysis import COLUMN_NAME_2_INDEX as COLUMN_IDX
+import seaborn as sns
+import matplotlib.pyplot as plt
+from  matplotlib import rcParams
+import matplotlib.font_manager as fm
+from scipy import stats
+import numpy as np
+import pandas as pd
+
+def label_queue():
+    """
+    从统计文件中筛选标签,返回标签数据,如果需要修改内部参数请对应修改config_analysis文件
+    """
+    with open(config_analysis.INPUT_CSV_FILE) as csv_file_handler:
+        csv_reader = csv.reader(csv_file_handler)
+        next(csv_reader)  # ['名称', '编码', '单位', '精度', '设备号', '记录数', '最小时间', '最大时间']
+        for row in csv_reader:  # row: list
+            # Filter by record count
+            if int(row[COLUMN_IDX['记录数']]) < config_analysis.DATA_MIN_RECORDS: continue
+            yield {'name': row[COLUMN_IDX['名称']], 'code': row[COLUMN_IDX['编码']]}
+
+def diff_tool(name:str, frame: pd.DataFrame, col:str):
+    """First-difference cumulative-style series (totals, counters, runtimes) to make them stationary"""
+    words = ['累计', '计数', '运行时间']
+    for word in words:
+        if word in name:
+            frame[col] = frame[col].diff()
+            frame.dropna(subset=[col], inplace=True)
+            break  # difference at most once, even if several keywords match
+    return frame
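+# Example: a cumulative meter series [10., 12., 15.] is differenced to the per-step
+# increments [2., 3.] (the leading NaN produced by diff() is dropped)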
+
+def skip_tool(series_a_name:str, series_b_name:str):
+    """Skip pairs whose correlation is trivial or uninformative (e.g. temperature vs. temperature)"""
+    if '温度' in series_a_name and '温度' in series_b_name: return True
+    if '次数' in series_a_name and '次数' in series_b_name: return True
+    if '累计' in series_a_name and '累计' in series_b_name: return True
+    if '电流' in series_a_name and '电流' in series_b_name: return True
+    if '电压' in series_a_name and '电压' in series_b_name: return True
+    if '电流' in series_a_name and '温度' in series_b_name: return True
+    if '温度' in series_a_name and '电流' in series_b_name: return True
+    if '累计电量' in series_a_name and '累计电量' in series_b_name: return True
+    if '运行时间' in series_a_name and '累计电量' in series_b_name: return True
+    if '累计电量' in series_a_name and '运行时间' in series_b_name: return True
+    if '运行时间' in series_a_name and '运行时间' in series_b_name: return True
+    if '时间设定' in series_a_name and '时间设定' in series_b_name: return True
+    return False
+
+def set_chinese_font():
+    # 1. Clear the matplotlib cache (key step; kept for reference)
+    # cache_dir = os.path.expanduser('~/.cache/matplotlib')
+    # if os.path.exists(cache_dir):
+    #     print(f"清除Matplotlib缓存: {cache_dir}")
+    #     for file in os.listdir(cache_dir):
+    #         if file.endswith('.cache') or file.endswith('.json'):
+    #             os.remove(os.path.join(cache_dir, file))
+
+    # 2. Candidate Chinese font files
+    chinese_fonts = [
+        # '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',  # WenQuanYi Micro Hei
+        # '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc',  # WenQuanYi Zen Hei
+        # '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc',  # Noto Sans CJK
+        # '/usr/share/fonts/windows/msyh.ttc',  # Microsoft YaHei
+        '/usr/share/fonts/windows/simsun.ttc'  # SimSun
+    ]
+
+    # 3. Pick the first available font
+    selected_font = None
+    for font_path in chinese_fonts:
+        if os.path.exists(font_path):
+            selected_font = font_path
+            print(f"使用字体: {font_path}")
+            break
+
+    if selected_font is None:
+        print("警告: 未找到任何中文字体文件")
+        # Fall back to font family names
+        rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'Microsoft YaHei', 'SimSun']
+        rcParams['axes.unicode_minus'] = False
+    else:
+        # Register the font file with the font manager
+        fm.fontManager.addfont(selected_font)
+        # Resolve the font's family name
+        font_prop = fm.FontProperties(fname=selected_font)
+        font_name = font_prop.get_name()
+        print(f"字体名称: {font_name}")
+
+        # Set the global font
+        rcParams['font.family'] = 'sans-serif'
+        rcParams['font.sans-serif'] = [font_name]
+        rcParams['axes.unicode_minus'] = False
+def create_custom_heatmap(corr_matrix: pd.DataFrame, title:str="相关系数热力图") -> str:
+    # Figure size (scales with matrix size)
+    size_factor = max(0.5, min(1.5, len(corr_matrix) / 30))  # scale factor
+    fig_width = 9 + len(corr_matrix.columns) * 0.4 * size_factor
+    fig_height = 7 + len(corr_matrix.index) * 0.4 * size_factor
+
+    plt.figure(figsize=(fig_width, fig_height))
+
+    # Draw the heatmap
+    ax = sns.heatmap(
+        corr_matrix,
+        cmap="coolwarm",
+        center=0,
+        annot=True,  # annotate cells with values
+        fmt=".2f",
+        annot_kws={"size": 13 - len(corr_matrix) / 20},  # 动态调整注释大小
+        linewidths=0.5,
+        linecolor="white",
+        cbar_kws={"shrink": 0.8, "label": "皮尔逊相关系数"}
+    )
+
+    # Rotate the x-axis labels
+    plt.xticks(rotation=45, ha='right', fontsize=15)
+    plt.yticks(fontsize=15,rotation=0, ha='right')
+
+    # Title and axis labels
+    plt.title(title, fontsize=18, pad=20)
+    plt.xlabel(f"B序列影响因素(显著性p值<{config_analysis.P_VALUE_THRESHOLD})", fontsize=15)
+    plt.ylabel(f"A序列影响因素(显著性p值<{config_analysis.P_VALUE_THRESHOLD})", fontsize=15)
+
+    # Minor grid lines
+    ax.grid(True, which='minor', color='white', linestyle='-', linewidth=0.5)
+
+    # Tidy the layout
+    plt.tight_layout()
+
+    # Save the figure
+    output_file = f"{title.replace(' ', '_')}.png"
+    plt.savefig(output_file, dpi=300, bbox_inches='tight')
+    plt.close()
+    print(f"热力图已保存为: {output_file}")
+
+    return output_file
+
+def cross_corr(group_a:list, group_b:list, all_data:pd.DataFrame, code_2_name_dict:dict) -> pd.DataFrame:
+    """Pairwise Pearson correlations between two field groups; cells failing the p-value test stay NaN"""
+    # Build the cross-correlation matrix
+    corr_matrix = pd.DataFrame(index=group_a, columns=group_b, dtype=np.float32)
+    for a in group_a:
+        for b in group_b:
+            r, p_value = stats.pearsonr(all_data.loc[:, a], all_data.loc[:, b])
+            if p_value < config_analysis.P_VALUE_THRESHOLD:
+                corr_matrix.loc[a, b] = np.float32(r)
+    # Replace codes with readable names on the row/column labels
+    a_code_2_name = {code: code_2_name_dict.get(code) for code in group_a }
+    b_code_2_name = {code: code_2_name_dict.get(code) for code in group_b }
+    corr_matrix.rename(index=a_code_2_name, columns=b_code_2_name, inplace=True)
+    return corr_matrix
+
+def group_list(data:list, group_elements_num:int) -> list:
+    """Split a list into consecutive groups of at most group_elements_num elements"""
+    group_num = (len(data) + group_elements_num - 1) // group_elements_num  # ceiling division avoids a trailing empty group
+    group_code = []
+    num = 0
+    for g in range(group_num):
+        group_code.append(data[num:num + group_elements_num])
+        num += group_elements_num
+    return group_code
+
+if __name__ == '__main__':
+    label_q1 = label_queue()
+    label_q2 = label_queue()
+    # for i in label_q1:
+    #     print(i['name'], i['code'])

+ 221 - 0
utils/tools.py

@@ -0,0 +1,221 @@
+import sys
+sys.path.append('..')
+import pandas as pd
+from matplotlib import pyplot as plt
+import seaborn as sns
+import config
+import os
+from  matplotlib import rcParams
+import matplotlib.font_manager as fm
+import numpy as np
+from typing import Iterable
+import json
+
+
+def group_list(data:list, group_elements_num:int) -> list:
+    """Split a list into consecutive groups of at most group_elements_num elements"""
+    group_num = (len(data) + group_elements_num - 1) // group_elements_num  # ceiling division avoids a trailing empty group
+    group_code = []
+    num = 0
+    for g in range(group_num):
+        group_code.append(data[num:num + group_elements_num])
+        num += group_elements_num
+    return group_code
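+# Example: group_list(list(range(5)), 2) -> [[0, 1], [2, 3], [4]]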
+
+
+def fmt_date(start_year,end_year,
+             start_month,end_month,
+             start_day,end_day,
+             start_hour=0,end_hour=23,
+             start_minute=0,end_minute=59,
+             start_second=0,end_second=59):
+    fmt = lambda x: '0' + str(x) if abs(x) < 10 else str(x)
+    start_month = fmt(start_month)
+    end_month = fmt(end_month)
+    start_day = fmt(start_day)
+    end_day = fmt(end_day)
+    start_hour = fmt(start_hour)
+    end_hour = fmt(end_hour)
+    start_minute = fmt(start_minute)
+    end_minute = fmt(end_minute)
+    start_second = fmt(start_second)
+    end_second = fmt(end_second)
+
+    start_datetime = f'{start_year}-{start_month}-{start_day} {start_hour}:{start_minute}:{start_second}'
+    end_datetime = f'{end_year}-{end_month}-{end_day} {end_hour}:{end_minute}:{end_second}'
+    return start_datetime, end_datetime
+
+def create_custom_heatmap(corr_matrix: pd.DataFrame, title:str="相关系数热力图") -> str:
+    """绘制热力图,输入协方差矩阵,自动生成热力图"""
+    corr_matrix.replace(0., np.nan, inplace=True)
+    # 设置图像尺寸(根据矩阵大小动态调整)
+    size_factor = max(0.5, min(1.5, len(corr_matrix) / 30))  # 缩放因子
+    fig_width = 9 + len(corr_matrix.columns) * 0.4 * size_factor
+    fig_height = 7 + len(corr_matrix.index) * 0.4 * size_factor
+
+    plt.figure(figsize=(fig_width, fig_height))
+
+    # Draw the heatmap
+    ax = sns.heatmap(
+        corr_matrix,
+        cmap="coolwarm",
+        center=0,
+        annot=True,  # annotate cells with values
+        fmt=".2f",
+        annot_kws={"size": 13 - len(corr_matrix) / 20},  # 动态调整注释大小
+        linewidths=0.5,
+        linecolor="white",
+        cbar_kws={"shrink": 0.8, "label": "皮尔逊相关系数"}
+    )
+
+    # Rotate the x-axis labels
+    plt.xticks(rotation=45, ha='right', fontsize=15)
+    plt.yticks(fontsize=15,rotation=0, ha='right')
+
+    # Title and axis labels
+    plt.title(title, fontsize=18, pad=20)
+    plt.xlabel(f"B序列影响因素(显著性p值<{config.P_VALUE_THRESHOLD})", fontsize=15)
+    plt.ylabel(f"A序列影响因素(显著性p值<{config.P_VALUE_THRESHOLD})", fontsize=15)
+
+    # Minor grid lines
+    ax.grid(True, which='minor', color='white', linestyle='-', linewidth=0.5)
+
+    # Tidy the layout
+    plt.tight_layout()
+
+    # Save the figure
+    output_file = f"{title.replace(' ', '_')}.png"
+    plt.savefig(output_file, dpi=300, bbox_inches='tight')
+    plt.close()
+    print(f"热力图已保存为: {output_file}")
+
+    return output_file
+
+def set_chinese_font():
+    """设置matplotlib中文字体"""
+    # 1. Clear the matplotlib cache (key step; kept for reference)
+    # cache_dir = os.path.expanduser('~/.cache/matplotlib')
+    # if os.path.exists(cache_dir):
+    #     print(f"清除Matplotlib缓存: {cache_dir}")
+    #     for file in os.listdir(cache_dir):
+    #         if file.endswith('.cache') or file.endswith('.json'):
+    #             os.remove(os.path.join(cache_dir, file))
+
+    # 2. Candidate Chinese font files
+    chinese_fonts = [
+        # '/usr/share/fonts/truetype/wqy/wqy-microhei.ttc',  # WenQuanYi Micro Hei
+        # '/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc',  # WenQuanYi Zen Hei
+        # '/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc',  # Noto Sans CJK
+        # '/usr/share/fonts/windows/msyh.ttc',  # Microsoft YaHei
+        '/usr/share/fonts/windows/simsun.ttc'  # SimSun
+    ]
+
+    # 3. Pick the first available font
+    selected_font = None
+    for font_path in chinese_fonts:
+        if os.path.exists(font_path):
+            selected_font = font_path
+            print(f"使用字体: {font_path}")
+            break
+
+    if selected_font is None:
+        print("警告: 未找到任何中文字体文件")
+        # Fall back to font family names
+        rcParams['font.sans-serif'] = ['WenQuanYi Micro Hei', 'Microsoft YaHei', 'SimSun']
+        rcParams['axes.unicode_minus'] = False
+    else:
+        # Register the font file with the font manager
+        fm.fontManager.addfont(selected_font)
+        # Resolve the font's family name
+        font_prop = fm.FontProperties(fname=selected_font)
+        font_name = font_prop.get_name()
+        print(f"字体名称: {font_name}")
+
+        # Set the global font
+        rcParams['font.family'] = 'sans-serif'
+        rcParams['font.sans-serif'] = [font_name]
+        rcParams['axes.unicode_minus'] = False
+
+def cal_vari_without_zero_nan(data:Iterable, l='None', is_exclude_zero=True)-> tuple:
+    """Mean and standard deviation of a series, ignoring zeros (optionally) and NaN"""
+    if not isinstance(data, pd.Series):
+        raise TypeError("data must be pd.Series")
+    # Collect the valid values
+    tem_value_list = []
+    for x in data:
+        if (abs(x - 0.) < 1e-6) and is_exclude_zero: continue
+        if pd.isna(x): continue
+        tem_value_list.append(x)
+    arr = np.array(tem_value_list)
+    # Sanity check: no NaN should remain
+    if np.sum(np.isnan(arr)) > 0:
+        raise ValueError(f'数据{l}中仍存在nan没有被剔除')
+    mean = np.mean(arr, dtype=np.float32)  # mean
+    std_dev = np.std(arr, dtype=np.float32)  # standard deviation
+    return mean, std_dev
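+# Example: cal_vari_without_zero_nan(pd.Series([0., 1., 2., np.nan, 3.])) skips the 0 and the
+# NaN and returns roughly (2.0, 0.8165) (population standard deviation, ddof=0)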
+
+def cal_vari_without_nan(data:Iterable, l='None')-> tuple:
+    """Mean and standard deviation, ignoring NaN only (zeros are kept)"""
+    return cal_vari_without_zero_nan(data, l, is_exclude_zero=False)
+
+def iqr(data:pd.Series)-> pd.Series:
+    """Remove outliers using the standard 1.5*IQR fences (the original stub was empty)."""
+    q1, q3 = data.quantile(0.25), data.quantile(0.75)
+    spread = q3 - q1
+    lower, upper = q1 - 1.5 * spread, q3 + 1.5 * spread
+    return data[(data >= lower) & (data <= upper)]
+
+def quicksort_part(arr:list, low:int, high:int):
+    """One quicksort partition pass, ordering tuples by the absolute value of their second element"""
+    if low >= high:
+        return None
+    # Pivot value
+    left, right = low, high
+    pivot = abs(arr[low][1])
+    # Larger values end up on the right, smaller on the left
+    while left < right:
+        # Scan from the right for a value below the pivot
+        while left < right and abs(arr[right][1]) >= pivot:
+            right -= 1
+        # Swap once
+        if left < right:
+            arr[left], arr[right] = arr[right], arr[left]
+            left += 1
+        # Scan from the left for a value above the pivot
+        while left < right and abs(arr[left][1]) <= pivot:
+            left += 1
+        # Swap once
+        if left < right:
+            arr[left], arr[right] = arr[right], arr[left]
+            right -= 1
+    return left
+
+def quick_sort(arr:list[tuple], low:int, high:int):
+    """In-place quicksort of (name, value) tuples by |value|"""
+    if low >= high:
+        return
+    # Partition once
+    mid = quicksort_part(arr, low, high)
+    # Sort the left part
+    quick_sort(arr, low, mid-1)
+    # Sort the right part
+    quick_sort(arr, mid+1, high)
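+# Note: quick_sort(arr, 0, len(arr) - 1) sorts arr by |value|, equivalent in effect to
+# arr.sort(key=lambda t: abs(t[1])) (ignoring tie order, since quicksort is not stable)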
+
+def df_is_symetry(df_mat:pd.DataFrame) -> bool:
+    """Check whether a DataFrame matrix is symmetric"""
+    if df_mat.shape[0] != df_mat.shape[1]:
+        return False
+
+    # Row and column labels must match
+    if not np.array_equal(df_mat.index, df_mat.columns):
+        return False
+
+    # Compare the values with their transpose numerically
+    return np.allclose(df_mat.values, df_mat.values.T, rtol=1e-5, atol=1e-08)
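+# Example: df_is_symetry(pd.DataFrame([[1., .5], [.5, 1.]], index=['a', 'b'], columns=['a', 'b'])) -> True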
+
+def load_transfer_file_name_code(path):
+    """Load the name/code translation JSON and return the (name_2_code, code_2_name) dicts"""
+    if not os.path.exists(path):
+        raise FileNotFoundError('文件未发现:', path)
+    with open(path, "r", encoding="utf-8") as f:
+        json_data = json.load(f)
+    return json_data.get('name_2_code'), json_data.get('code_2_name')
+
+if __name__ == '__main__':
+    pass