added ai.service.ts file
This commit is contained in:
144
src/app/ai.service.ts
Normal file
144
src/app/ai.service.ts
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
import * as tf from "@tensorflow/tfjs";
|
||||||
|
|
||||||
|
export class AiService {
|
||||||
|
|
||||||
|
constructor() {
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private currentModel;
|
||||||
|
flipX(arr) {
|
||||||
|
return [arr.slice(6), arr.slice(3, 6), arr.slice(0, 3)].flat();
|
||||||
|
};
|
||||||
|
|
||||||
|
flipY(arr) {
|
||||||
|
|
||||||
|
|
||||||
|
this.flipX(arr.slice().reverse());
|
||||||
|
}
|
||||||
|
// Creates a 1 hot of the diff
|
||||||
|
showMove(first, second) {
|
||||||
|
let result = [];
|
||||||
|
first.forEach((move, i) => {
|
||||||
|
result.push(Math.abs(move - second[i]));
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
|
||||||
|
getMoves (block) {
|
||||||
|
let x = [];
|
||||||
|
let y = [];
|
||||||
|
// Make all the moves
|
||||||
|
for (let i = 0; i < block.length - 1; i++) {
|
||||||
|
const theMove = this.showMove(block[i], block[i + 1]);
|
||||||
|
// Normal move
|
||||||
|
x.push(block[i]);
|
||||||
|
y.push(theMove);
|
||||||
|
// Flipped X move
|
||||||
|
x.push(this.flipX(block[i]));
|
||||||
|
y.push(this.flipX(theMove));
|
||||||
|
// Inverted Move
|
||||||
|
x.push(block[i].slice().reverse());
|
||||||
|
y.push(theMove.slice().reverse());
|
||||||
|
// Flipped Y move
|
||||||
|
x.push(this.flipY(block[i]));
|
||||||
|
y.push(this.flipY(theMove));
|
||||||
|
}
|
||||||
|
return { x, y };
|
||||||
|
};
|
||||||
|
|
||||||
|
constructModel() {
|
||||||
|
this.currentModel && this.currentModel.dispose();
|
||||||
|
tf.disposeVariables();
|
||||||
|
|
||||||
|
const model = tf.sequential();
|
||||||
|
|
||||||
|
model.add(
|
||||||
|
tf.layers.dense({
|
||||||
|
inputShape: [9],
|
||||||
|
units: 64,
|
||||||
|
activation: "relu"
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
model.add(
|
||||||
|
tf.layers.dense({
|
||||||
|
units: 64,
|
||||||
|
activation: "relu"
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
model.add(
|
||||||
|
tf.layers.dense({
|
||||||
|
units: 9,
|
||||||
|
activation: "softmax"
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
const learningRate = 0.005;
|
||||||
|
model.compile({
|
||||||
|
optimizer: tf.train.adam(learningRate),
|
||||||
|
loss: "categoricalCrossentropy",
|
||||||
|
metrics: ["accuracy"]
|
||||||
|
});
|
||||||
|
|
||||||
|
this.currentModel = model;
|
||||||
|
return model;
|
||||||
|
};
|
||||||
|
|
||||||
|
getModel() {
|
||||||
|
if (this.currentModel) {
|
||||||
|
return this.currentModel;
|
||||||
|
} else {
|
||||||
|
return this.constructModel();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
async trainOnGames(games, setState) {
|
||||||
|
const model = this.constructModel();
|
||||||
|
// model.dispose();
|
||||||
|
let AllX = [];
|
||||||
|
let AllY = [];
|
||||||
|
|
||||||
|
// console.log("Games in", JSON.stringify(games));
|
||||||
|
games.forEach((game) => {
|
||||||
|
AllX = AllX.concat(game.x);
|
||||||
|
AllY = AllY.concat(game.y);
|
||||||
|
});
|
||||||
|
|
||||||
|
// Tensorfy!
|
||||||
|
const stackedX = tf.stack(AllX);
|
||||||
|
const stackedY = tf.stack(AllY);
|
||||||
|
await this.trainModel(model, stackedX, stackedY);
|
||||||
|
|
||||||
|
// clean up!
|
||||||
|
stackedX.dispose();
|
||||||
|
stackedY.dispose();
|
||||||
|
|
||||||
|
setState(model);
|
||||||
|
// return updatedModel;
|
||||||
|
};
|
||||||
|
|
||||||
|
async trainModel (model, stackedX, stackedY) {
|
||||||
|
const allCallbacks = {
|
||||||
|
// onTrainBegin: log => console.log(log),
|
||||||
|
// onTrainEnd: log => console.log(log),
|
||||||
|
// onEpochBegin: (epoch, log) => console.log(epoch, log),
|
||||||
|
onEpochEnd: (epoch, log) => console.log(epoch, log)
|
||||||
|
// onBatchBegin: (batch, log) => console.log(batch, log),
|
||||||
|
// onBatchEnd: (batch, log) => console.log(batch, log)
|
||||||
|
};
|
||||||
|
|
||||||
|
await model.fit(stackedX, stackedY, {
|
||||||
|
epochs: 100,
|
||||||
|
shuffle: true,
|
||||||
|
batchSize: 32,
|
||||||
|
callbacks: allCallbacks
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log("Model Trained");
|
||||||
|
|
||||||
|
return model;
|
||||||
|
};
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user