102 changes: 77 additions & 25 deletions cli/src/hellaswag_gpt.ts
@@ -1,50 +1,102 @@
// import fs from 'fs';
import fsPromise from 'node:fs/promises';

import { dirname } from 'path';
import { fileURLToPath } from 'url';
import { parse } from 'ts-command-line-args'

import '@tensorflow/tfjs-node';
import fs from 'node:fs';
import path from 'node:path';
import { Tokenizer, models } from '@epfml/discojs';
import { models, serialization, Tokenizer } from '@epfml/discojs';
import { loadHellaSwag } from '@epfml/discojs-node';
// import { AutoTokenizer } from '@xenova/transformers';

const logFile = path.join('..', 'datasets', 'LogFile_hellaswag.txt');
const logLines: string[] = [];
const __dirname = dirname(fileURLToPath(import.meta.url));

const logLines: string[] = [];
function log(message: string) {
console.log(message);
logLines.push(message);
}

const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(-1)

async function evaluateTFJS(tokenizer: Tokenizer) {
const model = new models.GPT({ seed: 42 });
log('Evaluating TFJS GPT on HellaSwag...');
async function evaluateModel(model: models.GPT | models.ONNXModel, numDataPoints = -1) {
const hellaswagDataset: models.HellaSwagDataset = await loadHellaSwag(numDataPoints)
const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
log('Starting the HellaSwag benchmark...');

const start = Date.now();
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, true);
const duration = ((Date.now() - start) / 1000).toFixed(2);

log(`TFJS GPT Accuracy: ${(accuracy * 100).toFixed(2)}%`);
log(`TFJS GPT Evaluation Time: ${duration} seconds`);
log(`Final accuracy: ${(accuracy * 100).toFixed(2)}%`);
log(`Evaluation Time: ${duration} seconds`);
}

async function evaluateXenova(tokenizer: Tokenizer) {
const model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
log('Evaluating Xenova GPT-2 (ONNX) on HellaSwag...');
const ModelTypes = ['onnx', 'gpt-tfjs-random', 'gpt-tfjs-pretrained'] as const;
type ModelType = typeof ModelTypes[number];

const start = Date.now();
const accuracy = await models.evaluate_hellaswag(model, tokenizer, hellaswagDataset, false);
const duration = ((Date.now() - start) / 1000).toFixed(2);

log(`Xenova GPT-2 Accuracy: ${(accuracy * 100).toFixed(2)}%`);
log(`Xenova GPT-2 Evaluation Time: ${duration} seconds`);
interface HellaSwagArgs {
model: ModelType
numDataPoints: number
logFile: string
pretrainedModelPath: string
help?: boolean
}

async function main(): Promise<void> {
fs.writeFileSync(logFile, '', 'utf-8'); // Clear old log file
const defaultPretrainedModelPath = path.join(__dirname, "..", "..", "onnx-converter", "assets", "model.json")
const args = parse<HellaSwagArgs>({
model: {
type: (raw: string) => raw as ModelType,
description: `Model type, one of ${ModelTypes.toString()}`,
defaultValue: 'onnx'
},
numDataPoints: {
type: Number,
description: 'Number of HellaSwag datapoints to evaluate, set to -1 for the whole benchmark',
defaultValue: -1
},
logFile: {
type: String,
description: 'Relative path to the log file, defaults to ./hellaswag.log',
defaultValue: 'hellaswag.log'
},
pretrainedModelPath: {
type: String,
description: 'If specifying gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model',
defaultValue: defaultPretrainedModelPath
},
help: {
type: Boolean,
optional: true,
alias: 'h',
description: 'Prints this usage guide'
}
}, { helpArg: 'help' })

const tokenizer = await Tokenizer.from_pretrained('Xenova/gpt2');
await evaluateTFJS(tokenizer);
log('\n---\n');
await evaluateXenova(tokenizer);
const logFile = path.join(__dirname, args.logFile);
fs.writeFileSync(logFile, '', 'utf-8'); // Clear the log file

let model: models.GPT | models.ONNXModel | undefined;
switch (args.model) {
case 'onnx':
log("Using ONNX pretrained model Xenova/gpt2")
model = await models.ONNXModel.init_pretrained('Xenova/gpt2');
break;
case 'gpt-tfjs-random':
log("Using GPT-TFJS with random initialization")
model = new models.GPT({ seed: 42 });
break;
case 'gpt-tfjs-pretrained':
log("Using GPT-TFJS with pretrained weights")
if (args.pretrainedModelPath === undefined) {
throw new Error("If choosing gpt-tfjs-pretrained, provide the relative path to the TF.js pretrained model via `pretrainedModelPath`")
}
const encodedModel = await fsPromise.readFile(args.pretrainedModelPath);
model = await serialization.model.decode(encodedModel) as models.GPT;
break;
}
await evaluateModel(model, args.numDataPoints);

fs.writeFileSync(logFile, logLines.join('\n'), 'utf-8');
console.log(`\nResults written to ${logFile}`);
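For context on what evaluateModel measures: HellaSwag is a multiple-choice benchmark where each example pairs a context with four candidate endings, and the model's prediction is the ending to which it assigns the highest likelihood. A minimal sketch of that scoring loop follows; the HellaSwagExample shape and scoreEnding helper are illustrative assumptions, not the actual models.evaluate_hellaswag implementation.

// Hypothetical example shape and scorer, for illustration only.
interface HellaSwagExample {
  context: string;
  endings: string[]; // four candidate continuations
  label: number;     // index of the correct ending
}

// Assumed helper: total log-probability the model assigns to `ending`
// when it follows `context`.
type ScoreEnding = (context: string, ending: string) => Promise<number>;

async function hellaSwagAccuracy(
  examples: HellaSwagExample[],
  scoreEnding: ScoreEnding,
): Promise<number> {
  let correct = 0;
  for (const example of examples) {
    const scores = await Promise.all(
      example.endings.map((ending) => scoreEnding(example.context, ending)),
    );
    // Predict the ending with the highest log-probability.
    const predicted = scores.indexOf(Math.max(...scores));
    if (predicted === example.label) correct += 1;
  }
  return correct / examples.length;
}

With the ts-command-line-args setup above, the script would then be driven by flags such as --model onnx --numDataPoints 100 on the command line.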
3 changes: 3 additions & 0 deletions datasets/.gitignore
@@ -20,3 +20,6 @@

# GDHF demo
/tinder_dog/

# HellaSwag benchmark
hellaswag*
4 changes: 2 additions & 2 deletions discojs/src/default_tasks/cifar10.ts
@@ -14,9 +14,9 @@ export const cifar10: TaskProvider<"image", "decentralized"> = {
title: 'CIFAR10',
summary: {
preview: 'CIFAR-10 is a classic image classification task, and one of the most widely used datasets for machine learning research.',
overview: "The dataset contains 60,000 32x32 color images in 10 different classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. The official CIFAR-10 website can be found <a class='underline text-blue-400' href='https://www.cs.toronto.edu/~kriz/cifar.html' target='_blank'>here</a>. You can find a link to a sample dataset at the next step (Connect Your Data)."
overview: "The dataset contains 60,000 32x32 color images in 10 different classes: airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks. The official CIFAR-10 website can be found at https://www.cs.toronto.edu/~kriz/cifar.html . You can find a link to a sample dataset at the next step."
},
model: 'The model is a pretrained <a class="underline text-blue-400" target="_blank" href="https://github.com/tensorflow/tfjs-models/tree/master/mobilenet">MobileNetV1 model</a> trained in Tensorflow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.',
model: 'The model is a pretrained MobileNetV1 model trained in TensorFlow.js. The last output layer is replaced with a fully connected layer with softmax activation and one output neuron per CIFAR-10 category. The data preprocessing reshapes images into 224x224 pixels and normalizes values between 0 and 1. The neural network is optimized via Stochastic Gradient Descent and a categorical Cross Entropy loss.',
dataFormatInformation: 'Images should be of .png format and of size 32x32. <br> The CSV file should start with the exact header "filename,label", and each row should contain an image filename (without extension) and its label.<br><br> For example if you have images: 0.png (of a frog) and 1.png (of a car) <br> The CSV file should be: <br>filename, label <br><br> 0, frog <br> 1, car',
dataExample:
"https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/cifar10-example.png",
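As an aside, the preprocessing described in the model text above (reshape to 224x224, normalize into [0, 1]) is short in TensorFlow.js. This is a sketch of the described steps, not the task's actual preprocessing code:

import * as tf from '@tensorflow/tfjs';

// Resize a 32x32 RGB image to the 224x224 input MobileNetV1 expects,
// then scale pixel values from [0, 255] down to [0, 1].
function preprocess(image: tf.Tensor3D): tf.Tensor3D {
  const resized = tf.image.resizeBilinear(image, [224, 224]);
  return resized.div(255) as tf.Tensor3D;
}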
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/lus_covid.ts
@@ -12,7 +12,7 @@ export const lusCovid: TaskProvider<"image", "federated"> = {
title: 'Lung Ultrasound Image Classification',
summary: {
preview: "Medical images are a typical example of data that exists in huge quantity yet that can't be shared due to confidentiality reasons. Medical applications would immensely benefit from training on data currently locked. More data diversity leads to better generalization and bias mitigation.",
overview: "Disco allows data owners to collaboratively train machine learning models using their respective data without any privacy breach. This example problem is about diagnosing whether patients are positive or negative to COVID-19 from lung ultrasounds images. <br>Don't have a dataset of your own? You can find a link to a sample dataset at the next step."
overview: "Disco allows data owners to collaboratively train machine learning models using their respective data without any privacy breach. This example problem is about diagnosing whether patients are positive or negative to COVID-19 from lung ultrasounds images. You can find a link to a sample dataset at the next step."
},
model: "The model is a simple Convolutional Neural Network composed of two convolutional layers with ReLU activations and max pooling layers, followed by a fully connected output layer. The data preprocessing reshapes images into 100x100 pixels and normalizes values between 0 and 1",
dataFormatInformation: 'This model takes as input an image dataset of lung ultrasounds. The images are resized automatically.',
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/mnist.ts
@@ -12,7 +12,7 @@ export const mnist: TaskProvider<"image", "decentralized"> = {
title: 'Handwritten Digit Recognition',
summary: {
preview: "The MNIST handwritten digit classification problem is a classic dataset used in computer vision and deep learning. The objective is to classify handwritten digits from 28x28 pixel images.",
overview: "Download the classic MNIST dataset of hand-written numbers <a class='underline text-blue-400' target='_blank' href='https://www.kaggle.com/scolianni/mnistasjpg'>here</a>. You can also find a sample dataset at the next step."
overview: "Download the classic MNIST dataset of hand-written numbers at https://www.kaggle.com/scolianni/mnistasjpg . You can also find a sample dataset at the next step."
},
model: "The model is a simple Convolutional Neural Network composed of three convolutional layers with ReLU activations and max pooling layers, followed by two fully connected layers. The data preprocessing simply normalizes values between 0 and 1. The neural network is optimized via RMSProp and a categorical cross-entropy loss.",
dataFormatInformation: 'This model is trained on images corresponding to digits 0 to 9. You can connect your own images of each digit in the box corresponding to its label. The model takes images of size 28x28 as input.',
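For reference, a TensorFlow.js model matching the description above (three convolution/ReLU/max-pooling stages, two fully connected layers, RMSProp with categorical cross-entropy) could be sketched as below; the filter and unit counts are illustrative assumptions, not Disco's actual values:

import * as tf from '@tensorflow/tfjs';

function buildMnistModel(): tf.LayersModel {
  const model = tf.sequential();
  // Three conv/ReLU/max-pool stages over 28x28 grayscale inputs.
  model.add(tf.layers.conv2d({ inputShape: [28, 28, 1], filters: 8, kernelSize: 3, activation: 'relu' }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2 }));
  model.add(tf.layers.conv2d({ filters: 16, kernelSize: 3, activation: 'relu' }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2 }));
  model.add(tf.layers.conv2d({ filters: 32, kernelSize: 3, activation: 'relu' }));
  model.add(tf.layers.maxPooling2d({ poolSize: 2 }));
  // Two fully connected layers, ending in a 10-way softmax.
  model.add(tf.layers.flatten());
  model.add(tf.layers.dense({ units: 64, activation: 'relu' }));
  model.add(tf.layers.dense({ units: 10, activation: 'softmax' }));
  model.compile({ optimizer: 'rmsprop', loss: 'categoricalCrossentropy', metrics: ['accuracy'] });
  return model;
}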
4 changes: 2 additions & 2 deletions discojs/src/default_tasks/titanic.ts
@@ -12,10 +12,10 @@ export const titanic: TaskProvider<"tabular", "federated"> = {
title: 'Titanic Prediction',
summary: {
preview: "The Titanic classification task is one of the main entrypoints into machine learning. Using passenger data (name, age, gender, socio-economic class, etc), the goal is to identify who was more likely to survive the infamous shipwreck.",
overview: "The original competition can be found on <a target='_blank' class='underline text-blue-400' href='https://www.kaggle.com/c/titanic'>Kaggle</a> and a link to the training set can be found here <a target='_blank' class='underline text-blue-400' href='https://storage.googleapis.com/deai-313515.appspot.com/example_training_data/titanic_train.csv'>here</a>."
overview: ""
},
model: 'The model is a simple 5-layer feedforward network with ReLU activations. The model is optimized with Adam and binary cross-entropy loss. The preprocessing only fills missing values with a placeholder value (0).',
dataFormatInformation: 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific related Titanic data such as the ticket class bought by the passenger, its cabin number, etc.<br>The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked"<br>Each subsequent row contains passenger data.',
dataFormatInformation: 'The expected format for the tabular dataset is exactly the same as the sample data provided above or in the Kaggle competition. It is a CSV file with 12 columns. The features are general information about the passenger (sex, age, name, etc.) and specific Titanic-related data such as the ticket class bought by the passenger, their cabin number, etc. The first line of the CSV contains the header: "PassengerId, Survived, Pclass, Name, Sex, Age, SibSp, Parch, Ticket, Fare, Cabin, Embarked". Each subsequent row contains passenger data.',
dataExample: [
{ name: "PassengerId", data: "1" },
{ name: "Survived", data: "0" },
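The preprocessing mentioned in the model text above (fill missing values with the placeholder 0) amounts to something like the following sketch, assuming each row arrives as a record of possibly-missing numeric fields:

// Sketch: replace missing numeric values with the placeholder 0.
function fillMissing(row: Record<string, number | undefined>): Record<string, number> {
  const filled: Record<string, number> = {};
  for (const [key, value] of Object.entries(row)) {
    filled[key] = value ?? 0; // e.g. a missing Age becomes 0
  }
  return filled;
}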
2 changes: 1 addition & 1 deletion discojs/src/default_tasks/wikitext.ts
@@ -10,7 +10,7 @@ export const wikitext: TaskProvider<"text", "federated"> = {
title: "GPT Language Modeling",
summary: {
preview: 'Train a language model (L)LM in your browser, collaboratively and from scratch.',
overview: "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling, which you can download <a class='underline text-blue-400' target='_blank' href='https://dax-cdn.cdn.appdomain.cloud/dax-wikitext-103/1.0.1/wikitext-103.tar.gz'>here</a>. More information on how to connect the dataset at the next step."
overview: "You can train a GPT-2 model in your browser and in a collaborative manner on any textual dataset. As an example, you can try the Wikitext-103 dataset, composed of Wikipedia articles, widely used in natural language modeling. More information on how to connect the dataset at the next step."
},
model: [
"The model follows the exact GPT-2 architecture and is implemented in TensorFlow.js.",
10 changes: 4 additions & 6 deletions discojs/src/models/gpt/config.ts
@@ -12,15 +12,14 @@ export type GPTConfig = {
contextLength: number
vocabSize?: number
modelType: GPTModelType
name?: string,
evaluate?: boolean
maxEvalBatches?: number
evaluateEvery?: number
maxIter?: number
weightDecay?: number
verbose?: 0 | 1
debug?: boolean
dropout?: number
attnDrop?: number
residDrop?: number
embdDrop?: number
nLayer?: number
@@ -30,7 +29,6 @@
}
// for a benchmark of performance, see https://github.com/epfml/disco/pull/659
export const DefaultGPTConfig: Required<GPTConfig> = {
name: 'transformer', // prefix for the model layer names
lr: 0.001,
weightDecay: 0,
maxIter: 10,
@@ -42,9 +40,9 @@
contextLength: 128,
vocabSize: 50257,
debug: false,
dropout: 0.2,
residDrop: 0.2,
embdDrop: 0.2,
attnDrop: 0.1,
residDrop: 0.1,
embdDrop: 0.1,
nLayer: 3,
nHead: 3,
nEmbd: 48,
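With the single dropout rate split into the three site-specific rates above, an explicit override might look like this sketch, assuming the GPT constructor accepts these fields in its partial config, as the existing new models.GPT({ seed: 42 }) call suggests:

import { models } from '@epfml/discojs';

// Sketch: the former `dropout` knob is now three site-specific rates.
const model = new models.GPT({
  seed: 42,
  attnDrop: 0.1,  // dropout on attention probabilities
  residDrop: 0.1, // dropout on residual branches
  embdDrop: 0.1,  // dropout on token/position embeddings
});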
5 changes: 3 additions & 2 deletions discojs/src/models/gpt/layers.spec.ts
@@ -174,8 +174,9 @@ describe('GPT Layers', () => {
name: 'testCSA',
contextLength: 5,
nHead: 2,
nEmbd: 8, // divisible by nHead, so head size = 4
dropout: 0.0, // no dropout for deterministic tests
nEmbd: 8, // divisible by nHead, so head size = 4
attnDrop: 0.0, // no dropout for deterministic tests
residDrop: 0.0,
nLayer: 2,
seed: 42
};