Multiple Linear Regression

The more general case is when several variables influence a single target. For example, an article's word count, the number of articles, and the user's action counts (clicks, selections) can all affect the article's read count.

I adapted code found online and use gradient descent, which means computing the partial derivatives many times (100,000 iterations or more). Each partial derivative may be positive or negative, but as the iterations accumulate it approaches 0; eventually the unknown parameters w stop changing, and at that point the values of w have been found.
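
Concretely, each iteration applies the standard batch gradient descent update to every parameter (alpha is the learning rate, row is the number of training samples, x_i[j] is feature j of sample i):

theta_j := theta_j - alpha * (1/row) * sum over all samples i of ( h(x_i) - y_i ) * x_i[j],  where h(x_i) = theta0*x0 + theta1*x1 + theta2*x2

This is the rule that trainTheta() and compute_partial_derivative_for_theta() in the code below implement.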

Code:

package net.highersoft.svm.linear_regression.mul_unknown;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Arrays;

public class LinearRegression {
	/*
	 * Example training data:
	 *  x0   x1   x2    y
	 * 1.0  1.0  2.0  7.2
	 * 1.0  2.0  1.0  4.9
	 * 1.0  3.0  0.0  2.6
	 * 1.0  4.0  1.0  6.3
	 * 1.0  5.0 -1.0  1.0
	 * 1.0  6.0  0.0  4.7
	 * 1.0  7.0 -2.0 -0.6
	 *
	 * Note! The columns x1, x2 and y are the data the user actually supplies; x0 is an extra
	 * column added so that the derived formula stays uniform. x0, x1, x2 are the "features", y is the target.
	 *
	 * h(x) = theta0 * x0 + theta1 * x1 + theta2 * x2
	 *
	 * theta0, theta1, theta2 are the parameters we want to train.
	 *
	 * This program uses gradient descent.
	 */
	private double[][] trainData;// training data: one sample per row, the last value in each row is y
	private int row;// number of rows of training data
	private int column;// number of columns of training data

	private double[] theta;// the parameters theta

	private double alpha;// learning rate (step size)
	private int iteration;// number of iterations

	public static void main(String[] args) {
		// trainData columns: user action count, word count, read count
		LinearRegression m = new LinearRegression(LinearRegression.class.getResource("trainData").getFile(), 0.00001, 100000000);

		m.printTrainData();
		m.trainTheta();
		m.printTheta();

	}
	
	public LinearRegression(String fileName) {
		int rowoffile = getRowNumber(fileName);// number of rows in the training data file
		int columnoffile = getColumnNumber(fileName);// number of columns in the training data file

		trainData = new double[rowoffile][columnoffile + 1];// note the +1: feature x0 (always 1.0) is added so the formula stays uniform
		this.row = rowoffile;
		this.column = columnoffile + 1;

		this.alpha = 0.001;// learning rate defaults to 0.001
		this.iteration = 100000;// iteration count defaults to 100000

		theta = new double[column - 1];// h(x) = theta0 * x0 + theta1 * x1 + theta2 * x2 + ...
		initialize_theta();

		loadTrainDataFromFile(fileName, rowoffile, columnoffile);
	}

	public LinearRegression(String fileName, double alpha, int iteration) {
		int rowoffile = getRowNumber(fileName);// number of rows in the training data file
		int columnoffile = getColumnNumber(fileName);// number of columns in the training data file

		trainData = new double[rowoffile][columnoffile + 1];// note the +1: feature x0 (always 1.0) is added so the formula stays uniform
		this.row = rowoffile;
		this.column = columnoffile + 1;

		this.alpha = alpha;
		this.iteration = iteration;
		// the parameter vector theta has one entry fewer than the training data has columns
		theta = new double[column - 1];// h(x) = theta0 * x0 + theta1 * x1 + theta2 * x2 + ...
		initialize_theta();

		loadTrainDataFromFile(fileName, rowoffile, columnoffile);
	}

	private int getRowNumber(String fileName) {
		int count = 0;
		File file = new File(fileName);
		BufferedReader reader = null;
		try {
			reader = new BufferedReader(new FileReader(file));
			while (reader.readLine() != null)
				count++;
			reader.close();
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			if (reader != null) {
				try {
					reader.close();
				} catch (IOException e1) {
				}
			}
		}
		return count;

	}

	private int getColumnNumber(String fileName) {
		int count = 0;
		File file = new File(fileName);
		BufferedReader reader = null;
		try {
			reader = new BufferedReader(new FileReader(file));
			String tempString = reader.readLine();
			if (tempString != null)
				count = tempString.split(" ").length;
			reader.close();
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			if (reader != null) {
				try {
					reader.close();
				} catch (IOException e1) {
				}
			}
		}
		return count;
	}

	// initialize every theta parameter to 1.0
	private void initialize_theta(){
		for (int i = 0; i < theta.length; i++) {
			theta[i] = 1.0;
		}
	}

	public void trainTheta() {
		int iteration = this.iteration;
		while ((iteration--) > 0) {
			// compute the partial derivative with respect to each theta_i
			double[] partial_derivative = compute_partial_derivative();// partial derivatives
			// update every theta
			for (int i = 0; i < theta.length; i++) {
				theta[i] -= alpha * partial_derivative[i];
			}
			
			int printFlag=iteration+1;
			if(printFlag%1000==0) {
				System.out.println(printFlag+":"+Arrays.toString(theta));
			}
		}
	}

	private double[] compute_partial_derivative() {
		double[] partial_derivative = new double[theta.length];
		// for each theta (each column), compute the partial derivative
		for (int j = 0; j < theta.length; j++) {
			// column j
			partial_derivative[j] = compute_partial_derivative_for_theta(j);// partial derivative with respect to theta_j
		}
		return partial_derivative;
	}

	private double compute_partial_derivative_for_theta(int j) {
		double sum = 0.0;
		// iterate over every row of training data
		for (int i = 0; i < row; i++){
			sum += h_theta_x_i_minus_y_i_times_x_j_i(i, j);
		}
		return sum / row;
	}

	/**
	 * Computes the contribution of one training row to the partial derivative.
	 * 
	 * @param i  row index
	 * @param j  column index
	 * @return (h_theta(x_i) - y_i) * x_i[j]
	 */
	private double h_theta_x_i_minus_y_i_times_x_j_i(int i, int j) {
		double[] oneRow = getRow(i);// one row of data: the features come first, y is last
		double result = 0.0;
		// sum(theta * x)
		for (int k = 0; k < (oneRow.length - 1); k++)
			result += theta[k] * oneRow[k];
		// residual = sum(theta * x) - y
		result -= oneRow[oneRow.length - 1];
		result *= oneRow[j];
		return result;
	}

	private double[] getRow(int i)// return row i of the training data, i = 0, 1, ..., (row - 1)
	{
		return trainData[i];
	}

	private void loadTrainDataFromFile(String fileName, int row, int column) {
		// set the entire first column of trainData to 1.0 (feature x0)
		for (int i = 0; i < row; i++) {
			trainData[i][0] = 1.0;
		}

		File file = new File(fileName);
		BufferedReader reader = null;
		try {
			reader = new BufferedReader(new FileReader(file));
			String tempString = null;
			int counter = 0;
			while ((counter < row) && (tempString = reader.readLine()) != null) {
				String[] tempData = tempString.split(" ");
				for (int i = 0; i < column; i++) {
					trainData[counter][i + 1] = Double.parseDouble(tempData[i]);
				}
				counter++;
			}
			reader.close();
		} catch (IOException e) {
			e.printStackTrace();
		} finally {
			if (reader != null) {
				try {
					reader.close();
				} catch (IOException e1) {
				}
			}
		}
	}

	public void printTrainData() {
		System.out.println("Train Data:\n");
		for (int i = 0; i < column - 1; i++)
			System.out.printf("%10s", "x" + i + " ");
		System.out.printf("%10s", "y" + " \n");
		for (int i = 0; i < row; i++) {
			for (int j = 0; j < column; j++) {
				System.out.printf("%10s", trainData[i][j] + " ");
			}
			System.out.println();
		}
		System.out.println();
	}

	public void printTheta() {
		for (double a : theta)
			System.out.print(a + " ");
	}
}
Training data:

41 139.69 857
38 4.22 755
43 9.36 166
51 89.99 173
21 20.96 183
42 30.48 160
12 130.55 143
35 23.45 95
43 15.64 86
14 43.2 94
37 25.88 165
41 118.83 181
45 31.63 151
52 57.42 175
14 251.84 150
55 89.75 155
4 24.64 409
37 45.66 234
54 3.34 74
14 134.35 85
59 6.66 90
39 3.95 100
10 28.06 325
46 48.73 46
48 52.02 139
35 52.51 121
45 3.9 67
58 10.85 408
54 32.81 135
43 29.54 76
15 15.02 88
53 164.42 149
49 29.48 40
26 3.43 42
68 18.84 40
72 155.91 32
20 14.19 36
58 54.99 57
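
Once training finishes, the learned theta values can be used to estimate the read count of a new article. A minimal sketch, assuming the trained theta array is made accessible (the class above keeps theta private, so a getter would have to be added; the predict helper and the sample feature values below are my own illustration, not part of the program above):

	// Hypothetical helper: h(x) = theta[0] * 1.0 + theta[1] * x1 + theta[2] * x2, with x0 fixed at 1.0
	static double predict(double[] theta, double userActions, double wordCount) {
		return theta[0] + theta[1] * userActions + theta[2] * wordCount;
	}

	// Example call with made-up inputs, after trainTheta() has run:
	// double estimatedReads = predict(theta, 40, 50.0);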