added ai.service.ts file
This commit is contained in:
144
src/app/ai.service.ts
Normal file
144
src/app/ai.service.ts
Normal file
@@ -0,0 +1,144 @@
|
||||
import * as tf from "@tensorflow/tfjs";
|
||||
|
||||
export class AiService {
|
||||
|
||||
constructor() {
|
||||
|
||||
}
|
||||
|
||||
private currentModel;
|
||||
flipX(arr) {
|
||||
return [arr.slice(6), arr.slice(3, 6), arr.slice(0, 3)].flat();
|
||||
};
|
||||
|
||||
flipY(arr) {
|
||||
|
||||
|
||||
this.flipX(arr.slice().reverse());
|
||||
}
|
||||
// Creates a 1 hot of the diff
|
||||
showMove(first, second) {
|
||||
let result = [];
|
||||
first.forEach((move, i) => {
|
||||
result.push(Math.abs(move - second[i]));
|
||||
});
|
||||
return result;
|
||||
};
|
||||
|
||||
getMoves (block) {
|
||||
let x = [];
|
||||
let y = [];
|
||||
// Make all the moves
|
||||
for (let i = 0; i < block.length - 1; i++) {
|
||||
const theMove = this.showMove(block[i], block[i + 1]);
|
||||
// Normal move
|
||||
x.push(block[i]);
|
||||
y.push(theMove);
|
||||
// Flipped X move
|
||||
x.push(this.flipX(block[i]));
|
||||
y.push(this.flipX(theMove));
|
||||
// Inverted Move
|
||||
x.push(block[i].slice().reverse());
|
||||
y.push(theMove.slice().reverse());
|
||||
// Flipped Y move
|
||||
x.push(this.flipY(block[i]));
|
||||
y.push(this.flipY(theMove));
|
||||
}
|
||||
return { x, y };
|
||||
};
|
||||
|
||||
constructModel() {
|
||||
this.currentModel && this.currentModel.dispose();
|
||||
tf.disposeVariables();
|
||||
|
||||
const model = tf.sequential();
|
||||
|
||||
model.add(
|
||||
tf.layers.dense({
|
||||
inputShape: [9],
|
||||
units: 64,
|
||||
activation: "relu"
|
||||
})
|
||||
);
|
||||
|
||||
model.add(
|
||||
tf.layers.dense({
|
||||
units: 64,
|
||||
activation: "relu"
|
||||
})
|
||||
);
|
||||
|
||||
model.add(
|
||||
tf.layers.dense({
|
||||
units: 9,
|
||||
activation: "softmax"
|
||||
})
|
||||
);
|
||||
|
||||
const learningRate = 0.005;
|
||||
model.compile({
|
||||
optimizer: tf.train.adam(learningRate),
|
||||
loss: "categoricalCrossentropy",
|
||||
metrics: ["accuracy"]
|
||||
});
|
||||
|
||||
this.currentModel = model;
|
||||
return model;
|
||||
};
|
||||
|
||||
getModel() {
|
||||
if (this.currentModel) {
|
||||
return this.currentModel;
|
||||
} else {
|
||||
return this.constructModel();
|
||||
}
|
||||
};
|
||||
|
||||
async trainOnGames(games, setState) {
|
||||
const model = this.constructModel();
|
||||
// model.dispose();
|
||||
let AllX = [];
|
||||
let AllY = [];
|
||||
|
||||
// console.log("Games in", JSON.stringify(games));
|
||||
games.forEach((game) => {
|
||||
AllX = AllX.concat(game.x);
|
||||
AllY = AllY.concat(game.y);
|
||||
});
|
||||
|
||||
// Tensorfy!
|
||||
const stackedX = tf.stack(AllX);
|
||||
const stackedY = tf.stack(AllY);
|
||||
await this.trainModel(model, stackedX, stackedY);
|
||||
|
||||
// clean up!
|
||||
stackedX.dispose();
|
||||
stackedY.dispose();
|
||||
|
||||
setState(model);
|
||||
// return updatedModel;
|
||||
};
|
||||
|
||||
async trainModel (model, stackedX, stackedY) {
|
||||
const allCallbacks = {
|
||||
// onTrainBegin: log => console.log(log),
|
||||
// onTrainEnd: log => console.log(log),
|
||||
// onEpochBegin: (epoch, log) => console.log(epoch, log),
|
||||
onEpochEnd: (epoch, log) => console.log(epoch, log)
|
||||
// onBatchBegin: (batch, log) => console.log(batch, log),
|
||||
// onBatchEnd: (batch, log) => console.log(batch, log)
|
||||
};
|
||||
|
||||
await model.fit(stackedX, stackedY, {
|
||||
epochs: 100,
|
||||
shuffle: true,
|
||||
batchSize: 32,
|
||||
callbacks: allCallbacks
|
||||
});
|
||||
|
||||
console.log("Model Trained");
|
||||
|
||||
return model;
|
||||
};
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user