Skip to content

Commit

Permalink
almost done
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexStan0 committed Jun 15, 2023
1 parent aa7589e commit 76d7208
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
64 changes: 48 additions & 16 deletions stable-diffusion/aiModel.js
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,15 @@ class aiModel {
/**
* Uploads files using the Replicate API
* @param {string} filePath paths to the zip files to be uploaded
* @returns {string[]} serving data for the API
* @returns {object} serving data for the API
*/
uploadData = (filePath, instancePrompt, classPrompt) => {

//create a variable that points to the sh script
const scriptPath = 'upload.sh'

//declare an empty array to store the serving urls for the uploaded files
const servingData = [];
//declare an empty object to store the serving urls for the uploaded files
const servingData = {};

//if the file is not a zip file, skip it
if(path.extname(filePath) !== ".zip") throw new Error("Please provide valid data");
Expand All @@ -123,10 +123,10 @@ class aiModel {
//remove the '\n' from the end of the string
scriptOutput = scriptOutput.replace(/^\s+|\s+$/g, '');

//push the data to an array
servingData.push(scriptOutput);
servingData.push(instancePrompt);
servingData.push(classPrompt);
//push the data to an object
servingData.servingUrl = scriptOutput;
servingData.instancePrompt = instancePrompt;
servingData.classPrompt = classPrompt;

return servingData;

Expand All @@ -135,7 +135,10 @@ class aiModel {
/**
* Make a training call to the replicate dreambooth API
* @param {number} maxTrainSteps the number of training iterations
* @param {...string[]} servingData data to create a training call with
* @param {...object} servingData data to create a training call with
* @param {string} [servingData.servingURL] REQUIRED, the url that stores the training data
* @param {string} [servingData.instancePrompt] REQUIRED, prompt with rare 3 character token
* @param {string} [servingData.classPrompt] REQUIRED, prompt that describes the data set
* @returns an array of the training URLS to check the status of the training call with
*/
trainModel(maxTrainSteps = 2000, ...servingData) {
Expand All @@ -146,13 +149,10 @@ class aiModel {
//loop over all the data sets provided and train the model with them
for(const data of servingData) {

//make sure that the data set provided is valid. If it isn't, skip it
if(data.length !== 3) continue;

//store the data inside the array in an appropriate variable
const servingUrl = data[0];
const instancePrompt = data[1];
const classPrompt = data[2];
const servingUrl = data.servingUrl;
const instancePrompt = data.instancePrompt;
const classPrompt = data.classPrompt;

//create an object to store the input parameters unique to each training call
const bodyInput = {
Expand Down Expand Up @@ -201,6 +201,40 @@ class aiModel {

} //end trainModel()

/**
* Checks on the status of training calls
* @param {...string} trainingId
* @returns array of the training IDs
*/
checkTrainingStatus(...trainingId) {

//create an object to store the status of the training calls
const statuses = {};

//loop through all the training id's
for(const id in trainingId) {

//store the location of the training call
const trainingUrl = `https://dreambooth-api-experimental.replicate.com/v1/trainings/${id}`;

//store the header options
const headers = {
'Authorization': `Token ${this.#apiKey}`
}

//fetch the status of the training call through the API
fetch(trainingUrl, headers).then(response => {
if(response.ok) statuses.id = response.status;
})

} //end for-loop

return statuses;

} //end checkTrainingStatus

//TODO: Create a function to generate an image based on the user positive prompt and negative prompt

/**
* Executes a bash command
* @param {string} command what to execute
Expand Down Expand Up @@ -239,8 +273,6 @@ class aiModel {

} //end writeToFile()

//TODO: Create a function to get a prediction (i.e generate an image based on user promt)

} //end trainingModel

export default aiModel;
3 changes: 2 additions & 1 deletion stable-diffusion/start-training.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,5 @@ let landscape = myModel.uploadData("/mnt/c/Users/alexa/Documents/Data/LandscapeS
console.log(cityscape);
console.log(landscape);

//myModel.trainModel(2000, cityscape, landscape);
let idArr = myModel.trainModel(2000, cityscape, landscape);
myModel.checkTrainingStatus(...idArr);

0 comments on commit 76d7208

Please sign in to comment.