
Implementing a Basic Linear Regression Model in TensorFlow.js
Learn how to create a simple linear regression model using TensorFlow.js, focusing on core concepts such as defining the model, training, and making predictions.
Building a Linear Regression Model with TensorFlow.js
Linear regression is a fundamental machine learning technique that predicts a target variable from a linear relationship with one or more features. TensorFlow.js lets you implement it directly in JavaScript, so the model can be trained and run in the browser (or in Node.js) without a separate backend.
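For a single feature, the best-fitting line under squared error also has a closed-form solution: slope $w = \sum_i (x_i - \bar{x})(y_i - \bar{y}) / \sum_i (x_i - \bar{x})^2$ and intercept $b = \bar{y} - w\bar{x}$. This is not how TensorFlow.js trains the model, but it is a handy way to check what a trained model ought to converge to. A minimal plain JavaScript sketch (`fitLine` is a hypothetical helper name):

```javascript
// Hypothetical helper: ordinary least-squares fit of y = w * x + b for one feature.
// Used here only as a sanity check; TensorFlow.js learns w and b by gradient descent instead.
function fitLine(xs, ys) {
  if (xs.length !== ys.length || xs.length < 2) {
    throw new Error('fitLine needs two equal-length arrays with at least 2 points');
  }
  const n = xs.length;
  const meanX = xs.reduce((sum, x) => sum + x, 0) / n;
  const meanY = ys.reduce((sum, y) => sum + y, 0) / n;

  let sxy = 0; // sum of (x - meanX) * (y - meanY)
  let sxx = 0; // sum of (x - meanX)^2
  for (let i = 0; i < n; i++) {
    sxy += (xs[i] - meanX) * (ys[i] - meanY);
    sxx += (xs[i] - meanX) ** 2;
  }
  if (sxx === 0) {
    throw new Error('All x values are identical; the slope is undefined');
  }
  const w = sxy / sxx;
  const b = meanY - w * meanX;
  return { w, b };
}

// The four points used in this tutorial lie exactly on y = 2x - 1.
const { w, b } = fitLine([1, 2, 3, 4], [1, 3, 5, 7]);
console.log(w, b); // 2 -1
```

Because these points lie exactly on a line, the closed form recovers $w = 2$ and $b = -1$; a model trained by gradient descent should approach the same values.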
Key Steps to Implement Linear Regression
- Define the Model Architecture
- Compile the Model
- Train the Model
- Make Predictions
Example: A Simple Linear Regression Model
Below is a step-by-step guide to creating a basic linear regression model using TensorFlow.js.
```javascript
import * as tf from '@tensorflow/tfjs';

// Step 1: Define the model architecture
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [1] })); // Single input and output

// Step 2: Compile the model
model.compile({
  optimizer: tf.train.sgd(0.1), // Stochastic Gradient Descent with a learning rate of 0.1
  loss: 'meanSquaredError', // Loss function
});

// Step 3: Prepare the training data as [numExamples, numFeatures] tensors
const xs = tf.tensor2d([1, 2, 3, 4], [4, 1]); // Features
const ys = tf.tensor2d([1, 3, 5, 7], [4, 1]); // Labels (y = 2x - 1)

// Step 4: Train the model
(async () => {
  await model.fit(xs, ys, {
    epochs: 100, // Number of passes over the training data
    verbose: 0, // Suppress training logs
  });

  // Step 5: Make predictions (the input must also be 2D: one example, one feature)
  const prediction = model.predict(tf.tensor2d([5], [1, 1]));
  prediction.print(); // Close to [[9]], since y = 2(5) - 1; the exact value depends on the random initial weight
})();
```
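Explanation of the Code
- Model Definition: A sequential model with a single one-unit dense layer computes $y = wx + b$, where the weight $w$ and bias $b$ are learned during training.
- Compilation: The model is compiled with the `sgd` optimizer (learning rate 0.1) and the `meanSquaredError` loss, the average of the squared differences between predictions and labels.
- Training Data: `xs` holds the input values, and `ys` holds the corresponding target outputs, which follow $y = 2x - 1$.
- Training the Model: `model.fit` runs 100 epochs (full passes over the training data), adjusting $w$ and $b$ to minimize the loss.
- Prediction: After training, the model predicts $y$ for a new $x$; for $x = 5$ the output should be close to, but not exactly, 9.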
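To make the training step less of a black box, the sketch below performs by hand the kind of updates that `sgd` with `meanSquaredError` performs. The loss is $L = \frac{1}{n}\sum_i (w x_i + b - y_i)^2$, with gradients $\partial L/\partial w = \frac{2}{n}\sum_i (w x_i + b - y_i)\,x_i$ and $\partial L/\partial b = \frac{2}{n}\sum_i (w x_i + b - y_i)$. This is plain JavaScript, not TensorFlow.js, and it assumes full-batch updates (with only four examples and `model.fit`'s default batch size of 32, each epoch is a single gradient step). It also starts from $w = b = 0$ rather than a random initial weight, so its numbers will differ slightly from a real run. `trainLinearRegression` is a hypothetical helper name.

```javascript
// Minimal sketch: full-batch gradient descent on mean squared error for y = w * x + b.
// Assumption: starts from w = b = 0 (TensorFlow.js initializes the dense weight randomly).
function trainLinearRegression(xs, ys, learningRate, epochs) {
  const n = xs.length;
  let w = 0;
  let b = 0;
  const lossHistory = [];

  for (let epoch = 0; epoch < epochs; epoch++) {
    let gradW = 0;
    let gradB = 0;
    let loss = 0;
    for (let i = 0; i < n; i++) {
      const error = w * xs[i] + b - ys[i]; // prediction minus target
      loss += error * error;
      gradW += error * xs[i];
      gradB += error;
    }
    loss /= n; // mean squared error before this epoch's update
    gradW *= 2 / n; // dL/dw
    gradB *= 2 / n; // dL/db

    // Plain gradient descent step: parameter -= learningRate * gradient.
    w -= learningRate * gradW;
    b -= learningRate * gradB;
    lossHistory.push(loss);
  }
  return { w, b, lossHistory };
}

const { w, b, lossHistory } = trainLinearRegression([1, 2, 3, 4], [1, 3, 5, 7], 0.1, 100);
console.log(`w=${w.toFixed(3)} b=${b.toFixed(3)} prediction(5)=${(w * 5 + b).toFixed(3)}`);
```

After 100 epochs the parameters land near, but not exactly on, $w = 2$ and $b = -1$, which is why the tutorial's prediction is only close to 9. On this particular data the learning rate cannot be raised much: with full-batch steps, values above roughly 0.12 make these updates diverge instead of converge.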
Visualizing the Results
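Visualization libraries such as tfjs-vis (imported as `tfvis`) can help you monitor the training process and check how well the model fits the data. tfjs-vis draws its charts in the browser page. For more complex datasets or models, visualization becomes increasingly important. The snippet below plots the training loss after every epoch: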
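```javascript
import * as tfvis from '@tensorflow/tfjs-vis';

// Monitor training (await must run inside an async function or at the top level of an ES module)
const history = await model.fit(xs, ys, {
  epochs: 100,
  callbacks: tfvis.show.fitCallbacks(
    { name: 'Training Performance' }, // Visor surface to draw the chart in
    ['loss'], // Metrics to plot
    { height: 200, callbacks: ['onEpochEnd'] } // Redraw the chart after every epoch
  ),
});
```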
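A loss curve shows how training progressed; to score the finished model it can also help to compute summary numbers. The plain JavaScript sketch below assumes the predictions have already been copied out of the tensor (for example with `await prediction.data()`, which resolves to a typed array). `regressionMetrics` is a hypothetical helper, and R² (the coefficient of determination) is a standard goodness-of-fit measure that the tutorial itself does not use.

```javascript
// Hypothetical helper: mean squared error and R^2 for regression predictions.
function regressionMetrics(predictions, targets) {
  const n = targets.length;
  if (n === 0 || predictions.length !== n) {
    throw new Error('predictions and targets must be non-empty and the same length');
  }
  const meanTarget = targets.reduce((sum, t) => sum + t, 0) / n;

  let squaredError = 0; // residual sum of squares
  let totalVariation = 0; // total sum of squares around the mean target
  for (let i = 0; i < n; i++) {
    squaredError += (predictions[i] - targets[i]) ** 2;
    totalVariation += (targets[i] - meanTarget) ** 2;
  }
  const mse = squaredError / n;
  // R^2 = 1 - SS_res / SS_tot; 1 means a perfect fit. Undefined (NaN) if all targets are equal.
  const r2 = totalVariation === 0 ? NaN : 1 - squaredError / totalVariation;
  return { mse, r2 };
}

// Illustrative, made-up predictions that are slightly off the targets y = 2x - 1.
console.log(regressionMetrics([1.1, 2.9, 5.0, 7.2], [1, 3, 5, 7]));
```

The MSE here is the same quantity the `meanSquaredError` loss minimizes, so it can be compared directly with the final value on the loss chart; R² adds a scale-free view of how much of the targets' variation the line explains.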